Bayesian inference for a logistic regression model (Part 6)

Part 6: Hamiltonian Monte Carlo (HMC)

Introduction

This is the sixth part in a series of posts on MCMC-based Bayesian inference for a logistic regression model. If you are new to this series, please go back to Part 1.

In the previous post we saw how to construct an MCMC algorithm utilising gradient information by considering a Langevin equation having our target distribution of interest as its equilibrium. This equation has a physical interpretation in terms of the stochastic dynamics of a particle in a potential equal to minus the log of the target density. It turns out that thinking about the deterministic dynamics of a particle in such a potential can lead to more efficient MCMC algorithms.

Hamiltonian dynamics

Hamiltonian dynamics is often presented as an extension of a fairly general version of Lagrangian dynamics. However, for our purposes a rather simple version is quite sufficient, based on basic concepts from Newtonian dynamics, familiar from school. Inspired by our Langevin example, we will consider the dynamics of a particle in a potential function V(q). We will see later why we want V(q) = -\log \pi(q) for our target of interest, \pi(\cdot). In the context of Hamiltonian (and Lagrangian) dynamics we typically use q as our position variable, rather than x.

The potential function induces a (conservative) force on the particle equal to -\nabla V(q) when the particle is at position q. Then Newton’s second law of motion, "F=ma", takes the form

\displaystyle \nabla V(q) + m \ddot{q} = 0.

In Newtonian mechanics, we often consider the position vector q as 3-dimensional. Here it will be n-dimensional, where n is the number of variables in our target. We can then think of our second law as governing a single n-dimensional particle of mass m, or n one-dimensional particles all of mass m. But in this latter case, there is no need to assume that all particles have the same mass, and we could instead write our law of motion as

\displaystyle \nabla V(q) + M \ddot{q} = 0,

where M is a diagonal matrix. But in fact, since we could change coordinates, there’s no reason to require that M is diagonal. All we need is that M is positive definite, so that we don’t have negative mass in any coordinate direction.

We will take the above equation as the fundamental law governing our dynamical system of interest. The motivation from Newtonian dynamics is interesting, but not required. What is important is that the dynamics of such a system are conservative, in a way that we will shortly make precise.

Our law of motion is a second-order differential equation, since it involves the second derivative of q wrt time. If you’ve ever studied differential equations, you’ll know that there is an easy way to turn a second order equation into a first order equation with twice the dimension by augmenting the system with the velocities. Here, it is more convenient to augment the system with "momentum" variables, p, which we define as p = M\dot{q}. Then we can write our second order system as a pair of first order equations

\displaystyle \dot{q} = M^{-1}p

\displaystyle \dot{p} = -\nabla V(q)

These are, in fact, Hamilton’s equations for this system, though this isn’t how they are typically written.

If we define the kinetic energy as

\displaystyle T(p) = \frac{1}{2}p^\text{T}M^{-1}p,

then the Hamiltonian

\displaystyle H(q,p) = V(q) + T(p),

representing the total energy in the system, is conserved, since

\displaystyle \dot{H} = \nabla V\cdot \dot{q} + \dot{p}^\text{T}M^{-1}p = \nabla V\cdot \dot{q} + \dot{p}^\text{T}\dot{q} = [\nabla V + \dot{p}]\cdot\dot{q} = 0.

So, if we obey our Hamiltonian dynamics, our trajectory in (q,p)-space will follow contours of the Hamiltonian. It’s also clear that the system is time-reversible, so flipping the sign of the momentum p and integrating will exactly reverse the direction in which the contours are traversed. Another quite important property of Hamiltonian dynamics is that they are volume preserving. This can be verified by checking that the divergence of the flow is zero:

\displaystyle \nabla\cdot(\dot{q},\dot{p}) = \nabla_q\cdot\dot{q} + \nabla_p\cdot\dot{p} = 0,

since \dot{q} is a function of p only and \dot{p} is a function of q only.
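As a quick sanity check of these properties, here is a minimal sketch in pure Python for the one-dimensional case V(q) = q^2/2 with M = 1 (a standard normal target), where Hamilton’s equations can be solved exactly and the flow is just a rotation in (q,p)-space. The names `hamiltonian` and `flow` are made up for this illustration and do not appear in the scripts for this series.

```python
import math

def hamiltonian(q, p):
    # H(q,p) = V(q) + T(p) with V(q) = q^2/2 and M = 1
    return 0.5*q*q + 0.5*p*p

def flow(q, p, t):
    # Exact solution of dq/dt = p, dp/dt = -q after time t (a rotation)
    c, s = math.cos(t), math.sin(t)
    return (q*c + p*s, p*c - q*s)

q0, p0 = 1.0, 0.5
q1, p1 = flow(q0, p0, 2.0)

# Energy is conserved along the trajectory
energy_drift = abs(hamiltonian(q1, p1) - hamiltonian(q0, p0))

# Flip the momentum, integrate for the same time, flip back:
# this retraces the contour and returns us to the starting point
q2, p2 = flow(q1, -p1, 2.0)
back = (q2, -p2)
```

Both `energy_drift` and the distance from `back` to `(q0, p0)` should be zero up to floating point rounding. Since a rotation has unit determinant, this particular flow is also visibly volume preserving.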

Hamiltonian Monte Carlo (HMC)

In Hamiltonian Monte Carlo we introduce an augmented target distribution,

\displaystyle \tilde \pi(q,p) \propto \exp[-H(q,p)]

It is clear from this definition that moves leaving the Hamiltonian invariant will also leave the augmented target density unchanged. By following the Hamiltonian dynamics, we will be able to make big (reversible) moves in the space that will be accepted with high probability. Also, our target factorises into two independent components as

\displaystyle \tilde \pi(q,p) \propto \exp[-V(q)]\exp[-T(p)],

and so choosing V(q)=-\log \pi(q) will ensure that the q-marginal is our real target of interest, \pi(\cdot). It’s also clear that our p-marginal is \mathcal N(0,M). This is also the full-conditional for p, so re-sampling p from this distribution and leaving q unchanged is a Gibbs move that will leave the augmented target invariant. Re-sampling p will be necessary to properly explore our augmented target, since this will move us to a different contour of H.

So, an idealised version of HMC would proceed as follows: First, update p by sampling from its known tractable marginal. Second, update p and q jointly by following the Hamiltonian dynamics. If this second move is regarded as a (deterministic) reversible M-H proposal, it will be accepted with probability one since it leaves the augmented target density unchanged. If we could exactly integrate Hamilton’s equations, this would be fine. In practice, however, we will need to use some imperfect numerical method for the integration step. But just as for MALA, we can regard the numerical method as a M-H proposal and correct for the fact that it is imperfect, preserving the exact augmented target distribution.

Hamiltonian systems admit nice numerical integration schemes called symplectic integrators. In HMC a simple alternating Euler method is typically used, known as the leap-frog algorithm. The component updates are all shear transformations, and therefore volume preserving, and exact reversibility is ensured by starting and ending with a half-step update of the momentum variables. In principle, to ensure reversibility of the proposal the momentum variables should be sign-flipped (reversed) to finish, but in practice this doesn’t matter since it doesn’t affect the evaluation of the Hamiltonian and it will then get refreshed, anyway.

So, advancing our system by a time step \epsilon can be done with

\displaystyle p(t+\epsilon/2) := p(t) - \frac{\epsilon}{2}\nabla V(q(t))

\displaystyle q(t+\epsilon) := q(t) + \epsilon M^{-1}p(t+\epsilon/2)

\displaystyle p(t+\epsilon) := p(t+\epsilon/2) - \frac{\epsilon}{2}\nabla V(q(t+\epsilon))

It is clear that if many such updates are chained together, adjacent momentum updates can be collapsed together, giving rise to the "leap-frog" nature of the algorithm, and therefore requiring roughly one gradient evaluation per \epsilon update, rather than two. This integrator is volume preserving and exactly reversible. For reasonably small \epsilon it follows the Hamiltonian dynamics reasonably well, but not exactly, and so it does not exactly preserve the Hamiltonian. However, it does make a good M-H proposal, and reasonable acceptance probabilities can often be obtained by chaining together l updates to advance the time of the system by T=l\epsilon. The "optimal" values of l and \epsilon will be highly problem dependent, but values of l=20 or l=50 are not unusual. There are various more-or-less standard methods for tuning these, but we will not consider them here.
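To make the behaviour of the integrator concrete, here is a minimal scalar sketch in pure Python. It is written in terms of the gradient of the potential, \nabla V, to match the update equations above (the implementations in this post instead pass in the gradient of the log target, which flips the sign). The function name `leapfrog`, the test potential V(q) = q^2/2 and the particular values of \epsilon and l are assumptions chosen purely for illustration.

```python
def leapfrog(q, p, grad_v, eps, l, m=1.0):
    # Initial half-step for the momentum
    p = p - 0.5*eps*grad_v(q)
    for i in range(l):
        q = q + eps*p/m                  # full step for the position
        if i < l - 1:
            p = p - eps*grad_v(q)        # two adjacent half-steps collapsed
        else:
            p = p - 0.5*eps*grad_v(q)    # final half-step for the momentum
    return q, p

grad_v = lambda q: q                     # V(q) = q^2/2
h = lambda q, p: 0.5*q*q + 0.5*p*p       # Hamiltonian with M = 1

q0, p0 = 1.0, 0.5
q1, p1 = leapfrog(q0, p0, grad_v, eps=0.1, l=20)
energy_error = abs(h(q1, p1) - h(q0, p0))      # small, but not zero

# Flip the momentum and integrate again: we land back on the start
q2, p2 = leapfrog(q1, -p1, grad_v, eps=0.1, l=20)
reverse_error = abs(q2 - q0) + abs(-p2 - p0)   # zero up to rounding
```

The energy error should be small but clearly non-zero, which is why the M-H correction is needed, whereas the reversal should be exact up to floating point rounding.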

Note that since our HMC update on the augmented space consists of a Gibbs move and a M-H update, it is important that our M-H kernel does not keep or thread through the old log target density from the previous M-H update, since the Gibbs move will have changed it in the meantime.

Implementations

R

We need a M-H kernel that does not thread through the old log density.

mhKernel = function(logPost, rprop)
    function(x) {
        prop = rprop(x)
        a = logPost(prop) - logPost(x)
        if (log(runif(1)) < a)
            prop
        else
            x
    }

We can then use this to construct a M-H move as part of our HMC update.

hmcKernel = function(lpi, glpi, eps = 1e-4, l=10, dmm = 1) {
    sdmm = sqrt(dmm)
    leapf = function(q, p) {
        p = p + 0.5*eps*glpi(q)
        for (i in 1:l) {
            q = q + eps*p/dmm
            if (i < l)
                p = p + eps*glpi(q)
            else
                p = p + 0.5*eps*glpi(q)
        }
        list(q=q, p=-p)
    }
    alpi = function(x)
        lpi(x$q) - 0.5*sum((x$p^2)/dmm)
    rprop = function(x)
        leapf(x$q, x$p)
    mhk = mhKernel(alpi, rprop)
    function(q) {
        d = length(q)
        x = list(q=q, p=rnorm(d, 0, sdmm))
        mhk(x)$q
    }
}

See the full runnable script for further details.

Python

First a M-H kernel,

def mhKernel(lpost, rprop):
    def kernel(x):
        prop = rprop(x)
        a = lpost(prop) - lpost(x)
        if (np.log(np.random.rand()) < a):
            x = prop
        return x
    return kernel

and then an HMC kernel.

def hmcKernel(lpi, glpi, eps = 1e-4, l=10, dmm = 1):
    sdmm = np.sqrt(dmm)
    def leapf(q, p):    
        p = p + 0.5*eps*glpi(q)
        for i in range(l):
            q = q + eps*p/dmm
            if (i < l-1):
                p = p + eps*glpi(q)
            else:
                p = p + 0.5*eps*glpi(q)
        return (q, -p)
    def alpi(x):
        (q, p) = x
        return lpi(q) - 0.5*np.sum((p**2)/dmm)
    def rprop(x):
        (q, p) = x
        return leapf(q, p)
    mhk = mhKernel(alpi, rprop)
    def kern(q):
        d = len(q)
        p = np.random.randn(d)*sdmm
        return mhk((q, p))[0]
    return kern

See the full runnable script for further details.
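To see the whole update working end to end without any data, here is a scalar, pure Python sketch with the same structure as the kernel above, run on a standard normal target. The name `hmc_kernel`, the step size, the number of leap-frog steps, the seed and the burn-in length are all assumptions made for this toy example, not settings taken from the runnable script.

```python
import math
import random

def hmc_kernel(lpi, glpi, eps=0.3, l=5, dmm=1.0):
    # Scalar analogue of hmcKernel: lpi is the log target, glpi its derivative
    sdmm = math.sqrt(dmm)
    def leapf(q, p):
        p = p + 0.5*eps*glpi(q)
        for i in range(l):
            q = q + eps*p/dmm
            p = p + (eps if i < l - 1 else 0.5*eps)*glpi(q)
        return q, -p
    def alpi(q, p):
        # Log of the augmented target, up to an additive constant
        return lpi(q) - 0.5*p*p/dmm
    def kern(q):
        p = random.gauss(0.0, sdmm)        # Gibbs refresh of the momentum
        qn, pn = leapf(q, p)               # deterministic proposal
        a = alpi(qn, pn) - alpi(q, p)      # log acceptance ratio
        u = 1.0 - random.random()          # uniform on (0, 1], avoids log(0)
        return qn if math.log(u) < a else q
    return kern

random.seed(42)
kern = hmc_kernel(lambda q: -0.5*q*q, lambda q: -q)
q, draws = 3.0, []
for _ in range(6000):
    q = kern(q)
    draws.append(q)
draws = draws[1000:]                       # discard burn-in
mean = sum(draws)/len(draws)
var = sum((d - mean)**2 for d in draws)/len(draws)
```

The sample mean and variance should be close to 0 and 1. Note that the log augmented target is recomputed for the current state inside `kern`, since the momentum refresh changes it at every iteration.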

JAX

Again, we want an appropriate M-H kernel,

def mhKernel(lpost, rprop, dprop = jit(lambda new, old: 1.)):
    @jit
    def kernel(key, x):
        key0, key1 = jax.random.split(key)
        prop = rprop(key0, x)
        ll = lpost(x)
        lp = lpost(prop)
        a = lp - ll + dprop(x, prop) - dprop(prop, x)
        accept = (jnp.log(jax.random.uniform(key1)) < a)
        return jnp.where(accept, prop, x)
    return kernel

and then an HMC kernel.

def hmcKernel(lpi, glpi, eps = 1e-4, l = 10, dmm = 1):
    sdmm = jnp.sqrt(dmm)
    @jit
    def leapf(q, p):    
        p = p + 0.5*eps*glpi(q)
        for i in range(l):
            q = q + eps*p/dmm
            if (i < l-1):
                p = p + eps*glpi(q)
            else:
                p = p + 0.5*eps*glpi(q)
        return jnp.concatenate((q, -p))
    @jit
    def alpi(x):
        d = len(x) // 2
        return lpi(x[jnp.array(range(d))]) - 0.5*jnp.sum((x[jnp.array(range(d,2*d))]**2)/dmm)
    @jit
    def rprop(k, x):
        d = len(x) // 2
        return leapf(x[jnp.array(range(d))], x[jnp.array(range(d, 2*d))])
    mhk = mhKernel(alpi, rprop)
    @jit
    def kern(k, q):
        key0, key1 = jax.random.split(k)
        d = len(q)
        x = jnp.concatenate((q, jax.random.normal(key0, [d])*sdmm))
        return mhk(key1, x)[jnp.array(range(d))]
    return kern

There is something a little bit strange about this implementation: since the proposal for the M-H move is deterministic, the function rprop just ignores the RNG key that is passed to it. We could tidy this up by making a M-H function especially for deterministic proposals. We won’t pursue this here, but this issue will crop up again later in some of the other functional languages.
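For illustration only, here is roughly what such a kernel could look like, sketched in plain Python rather than JAX to keep it dependency-free. The name `mdKernel` is borrowed from the Haskell and Dex versions in this post; the toy target and the two proposal maps are made up. The simple acceptance ratio used here is only valid when the deterministic proposal is a volume-preserving involution, as the sign-flipped leap-frog map is.

```python
import math
import random

def mdKernel(lpost, prop):
    # M-H kernel for a deterministic proposal `prop`, assumed to be a
    # volume-preserving involution, so no proposal density terms are needed
    def kernel(x):
        y = prop(x)
        a = lpost(y) - lpost(x)              # log acceptance ratio
        u = 1.0 - random.random()            # uniform on (0, 1], avoids log(0)
        return y if math.log(u) < a else x
    return kernel

lpost = lambda x: -0.5*x*x                   # standard normal log density
flip = mdKernel(lpost, lambda x: -x)         # leaves the density unchanged
jump = mdKernel(lpost, lambda x: 100.0 - x)  # proposes a move far into the tail

accepted = flip(1.5)    # density is unchanged, so this is (almost surely) accepted
rejected = jump(0.0)    # log ratio is -5000, so this is rejected
```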

See the full runnable script for further details.

Scala

A M-H kernel,

def mhKern[S](
    logPost: S => Double, rprop: S => S,
    dprop: (S, S) => Double = (n: S, o: S) => 1.0
  ): (S) => S =
    val r = Uniform(0.0,1.0)
    x0 =>
      val x = rprop(x0)
      val ll0 = logPost(x0)
      val ll = logPost(x)
      val a = ll - ll0 + dprop(x0, x) - dprop(x, x0)
      if (math.log(r.draw()) < a) x else x0

and an HMC kernel.

def hmcKernel(lpi: DVD => Double, glpi: DVD => DVD, dmm: DVD,
  eps: Double = 1e-4, l: Int = 10) =
  val sdmm = sqrt(dmm)
  def leapf(q: DVD, p: DVD): (DVD, DVD) = 
    @tailrec def go(q0: DVD, p0: DVD, l: Int): (DVD, DVD) =
      val q = q0 + eps*(p0/:/dmm)
      val p = if (l > 1)
        p0 + eps*glpi(q)
      else
        p0 + 0.5*eps*glpi(q)
      if (l == 1)
        (q, -p)
      else
        go(q, p, l-1)
    go(q, p + 0.5*eps*glpi(q), l)
  def alpi(x: (DVD, DVD)): Double =
    val (q, p) = x
    lpi(q) - 0.5*sum(pow(p,2) /:/ dmm)
  def rprop(x: (DVD, DVD)): (DVD, DVD) =
    val (q, p) = x
    leapf(q, p)
  val mhk = mhKern(alpi, rprop)
  (q: DVD) =>
    val d = q.length
    val p = sdmm map (sd => Gaussian(0,sd).draw())
    mhk((q, p))._1

See the full runnable script for further details.

Haskell

A M-H kernel:

mdKernel :: (StatefulGen g m) => (s -> Double) -> (s -> s) -> g -> s -> m s
mdKernel logPost prop g x0 = do
  let x = prop x0
  let ll0 = logPost x0
  let ll = logPost x
  let a = ll - ll0
  u <- (genContVar (uniformDistr 0.0 1.0)) g
  let next = if ((log u) < a)
        then x
        else x0
  return next

Note that here we are using a M-H kernel specifically for deterministic proposals, since there is no non-determinism signalled in the type signature of prop. We can then use this to construct our HMC kernel.

hmcKernel :: (StatefulGen g m) =>
  (Vector Double -> Double) -> (Vector Double -> Vector Double) -> Vector Double ->
  Double -> Int -> g ->
  Vector Double -> m (Vector Double)
hmcKernel lpi glpi dmm eps l g = let
  sdmm = cmap sqrt dmm
  leapf q p = let
    go q0 p0 l = let
      q = q0 + (scalar eps)*p0/dmm
      p = if (l > 1)
        then p0 + (scalar eps)*(glpi q)
        else p0 + (scalar (eps/2))*(glpi q)
      in if (l == 1)
      then (q, -p)
      else go q p (l - 1)
    in go q (p + (scalar (eps/2))*(glpi q)) l
  alpi x = let
    (q, p) = x
    in (lpi q) - 0.5*(sumElements (p*p/dmm))
  prop x = let
    (q, p) = x
    in leapf q p
  mk = mdKernel alpi prop g
  in (\q0 -> do
         let d = size q0
         zl <- (replicateM d . genContVar (normalDistr 0.0 1.0)) g
         let z = fromList zl
         let p0 = sdmm * z
         (q, p) <- mk (q0, p0)
         return q)

See the full runnable script for further details.

Dex

Again we can use a M-H kernel specific to deterministic proposals.

def mdKernel {s} (lpost: s -> Float) (prop: s -> s)
    (x0: s) (k: Key) : s =
  x = prop x0
  ll0 = lpost x0
  ll = lpost x
  a = ll - ll0
  u = rand k
  select (log u < a) x x0

and use this to construct an HMC kernel.

def hmcKernel {n} (lpi: (Fin n)=>Float -> Float)
    (dmm: (Fin n)=>Float) (eps: Float) (l: Nat)
    (q0: (Fin n)=>Float) (k: Key) : (Fin n)=>Float =
  sdmm = sqrt dmm
  idmm = map (\x. 1.0/x) dmm
  glpi = grad lpi
  def leapf (q0: (Fin n)=>Float) (p0: (Fin n)=>Float) :
      ((Fin n)=>Float & (Fin n)=>Float) =
    p1 = p0 + (eps/2) .* (glpi q0)
    q1 = q0 + eps .* (p1*idmm)
    (q, p) = apply_n l (q1, p1) \(qo, po).
      pn = po + eps .* (glpi qo)
      qn = qo + eps .* (pn*idmm)
      (qn, pn)
    pf = p + (eps/2) .* (glpi q)
    (q, -pf)
  def alpi (qp: ((Fin n)=>Float & (Fin n)=>Float)) : Float =
    (q, p) = qp
    (lpi q) - 0.5*(sum (p*p*idmm))
  def prop (qp: ((Fin n)=>Float & (Fin n)=>Float)) :
      ((Fin n)=>Float & (Fin n)=>Float) =
    (q, p) = qp
    leapf q p
  mk = mdKernel alpi prop
  [k1, k2] = split_key k
  z = randn_vec k1
  p0 = sdmm * z
  (q, p) = mk (q0, p0) k2
  q

Note that the gradient is obtained via automatic differentiation. See the full runnable script for details.

Next steps

This was the main place that I was trying to get to when I started this series of posts. For differentiable log-posteriors (as we have in the case of Bayesian logistic regression), HMC is a pretty good algorithm for reasonably efficient posterior exploration. But there are lots of places we could go from here. We could explore the tuning of MCMC algorithms, or HMC extensions such as NUTS. We could look at MCMC algorithms that are specifically tailored to the logistic regression problem, or we could look at new MCMC algorithms for differentiable targets based on piecewise deterministic Markov processes. Alternatively, we could temporarily abandon MCMC and look at SMC or ABC approaches. Another possibility would be to abandon this multi-language approach and have a bit of a deep dive into Dex, which I think has the potential to be a great programming language for statistical computing. All of these are possibilities for the future, but I’ve a busy few weeks coming up, so the frequency of these posts is likely to substantially decrease.

Remember that all of the code associated with this series of posts is available from this GitHub repo.

Bayesian inference for a logistic regression model (Part 5)

Part 5: the Metropolis-adjusted Langevin algorithm (MALA)

Introduction

This is the fifth part in a series of posts on MCMC-based Bayesian inference for a logistic regression model. If you are new to this series, please go back to Part 1.

In the previous post we saw how to use Langevin dynamics to construct an approximate MCMC scheme using the gradient of the log target distribution. Each step of the algorithm involved simulating from the Euler-Maruyama approximation to the transition kernel of the process, based on some pre-specified step size, \Delta t. We can improve the accuracy of this approximation by making the step size smaller, but this will come at the expense of a more slowly mixing MCMC chain.

Fortunately, there is an easy way to make the algorithm "exact" (in the sense that the equilibrium distribution of the Markov chain will be the exact target distribution), for any finite step size, \Delta t, simply by using the Euler-Maruyama approximation as the proposal distribution in a Metropolis-Hastings algorithm. This is the Metropolis-adjusted Langevin algorithm (MALA). There are various ways this could be coded up, but here, for clarity, a HoF for generating a MALA kernel will be used, and this function will in turn call on a HoF for generating a Metropolis-Hastings kernel.

Implementations

R

First we need a function to generate a M-H kernel.

mhKernel = function(logPost, rprop, dprop = function(new, old, ...) { 1 })
    function(x, ll) {
        prop = rprop(x)
        llprop = logPost(prop)
        a = llprop - ll + dprop(x, prop) - dprop(prop, x)
        if (log(runif(1)) < a)
            list(x=prop, ll=llprop)
        else
            list(x=x, ll=ll)
    }

Then we can easily write a function for returning a MALA kernel that makes use of this M-H function.

malaKernel = function(lpi, glpi, dt = 1e-4, pre = 1) {
    sdt = sqrt(dt)
    spre = sqrt(pre)
    advance = function(x) x + 0.5*pre*glpi(x)*dt
    mhKernel(lpi, function(x) rnorm(length(x), advance(x), spre*sdt),
             function(new, old) sum(dnorm(new, advance(old), spre*sdt, log=TRUE)))
}

Notice that our MALA function requires as input both the gradient of the log posterior (for the proposal) and the log posterior itself (for the M-H correction). Other details are as we have already seen – see the full runnable script.

Python

Again, we need a M-H kernel

def mhKernel(lpost, rprop, dprop = lambda new, old: 1.):
    def kernel(x, ll):
        prop = rprop(x)
        lp = lpost(prop)
        a = lp - ll + dprop(x, prop) - dprop(prop, x)
        if (np.log(np.random.rand()) < a):
            x = prop
            ll = lp
        return x, ll
    return kernel

and then a MALA kernel

def malaKernel(lpi, glpi, dt = 1e-4, pre = 1):
    # the dimension is taken from the state itself, not from a global
    sdt = np.sqrt(dt)
    spre = np.sqrt(pre)
    advance = lambda x: x + 0.5*pre*glpi(x)*dt
    return mhKernel(lpi, lambda x: advance(x) + np.random.randn(len(x))*spre*sdt,
            lambda new, old: np.sum(sp.stats.norm.logpdf(new, loc=advance(old), scale=spre*sdt)))

See the full runnable script for further details.
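As a small self-contained check that the M-H correction really does target the right distribution even for a large step size, here is a scalar sketch in pure Python for a standard normal target. The name `mala_kernel`, the helper `lq`, the deliberately large dt = 1.0, the seed and the chain length are assumptions made for this illustration. For simplicity this version recomputes the log target at the current state rather than threading it through.

```python
import math
import random

def mala_kernel(lpi, glpi, dt):
    sdt = math.sqrt(dt)
    advance = lambda x: x + 0.5*glpi(x)*dt
    def lq(new, old):
        # Log density of the Euler-Maruyama proposal, up to a constant
        # that is the same in both directions and so cancels
        return -0.5*(new - advance(old))**2/dt
    def kernel(x):
        prop = advance(x) + sdt*random.gauss(0.0, 1.0)
        a = lpi(prop) - lpi(x) + lq(x, prop) - lq(prop, x)
        u = 1.0 - random.random()          # uniform on (0, 1], avoids log(0)
        return prop if math.log(u) < a else x
    return kernel

random.seed(1)
kernel = mala_kernel(lambda x: -0.5*x*x, lambda x: -x, dt=1.0)
x, draws = 0.0, []
for _ in range(20000):
    x = kernel(x)
    draws.append(x)
mean = sum(draws)/len(draws)
var = sum((d - mean)**2 for d in draws)/len(draws)
```

The sample mean and variance should be close to 0 and 1, despite the step size being far too large for the unadjusted algorithm to be accurate.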

JAX

If we want our algorithm to run fast, and if we want to exploit automatic differentiation to avoid the need to manually compute gradients, then we can easily convert the above code to use JAX.

def mhKernel(lpost, rprop, dprop = jit(lambda new, old: 1.)):
    @jit
    def kernel(key, x, ll):
        key0, key1 = jax.random.split(key)
        prop = rprop(key0, x)
        lp = lpost(prop)
        a = lp - ll + dprop(x, prop) - dprop(prop, x)
        accept = (jnp.log(jax.random.uniform(key1)) < a)
        return jnp.where(accept, prop, x), jnp.where(accept, lp, ll)
    return kernel

def malaKernel(lpi, dt = 1e-4, pre = 1):
    # differentiate the function that was passed in, not a global
    glpi = jit(grad(lpi))
    sdt = jnp.sqrt(dt)
    spre = jnp.sqrt(pre)
    advance = jit(lambda x: x + 0.5*pre*glpi(x)*dt)
    return mhKernel(lpi, jit(lambda k, x: advance(x) +
                             jax.random.normal(k, x.shape)*spre*sdt),
            jit(lambda new, old:
                jnp.sum(jsp.stats.norm.logpdf(new,
                      loc=advance(old), scale=spre*sdt))))

See the full runnable script for further details.

Scala

def mhKernel[S](
    logPost: S => Double, rprop: S => S,
    dprop: (S, S) => Double = (n: S, o: S) => 1.0
  ): ((S, Double)) => (S, Double) =
    val r = Uniform(0.0,1.0)
    state =>
      val (x0, ll0) = state
      val x = rprop(x0)
      val ll = logPost(x)
      val a = ll - ll0 + dprop(x0, x) - dprop(x, x0)
      if (math.log(r.draw()) < a)
        (x, ll)
      else
        (x0, ll0)

def malaKernel(lpi: DVD => Double, glpi: DVD => DVD, pre: DVD, dt: Double = 1e-4) =
  val sdt = math.sqrt(dt)
  val spre = sqrt(pre)
  val p = pre.length
  def advance(beta: DVD): DVD =
    beta + (0.5*dt)*(pre*:*glpi(beta))
  def rprop(beta: DVD): DVD =
    advance(beta) + sdt*spre.map(Gaussian(0,_).sample())
  def dprop(n: DVD, o: DVD): Double = 
    val ao = advance(o)
    (0 until p).map(i => Gaussian(ao(i), spre(i)*sdt).logPdf(n(i))).sum
  mhKernel(lpi, rprop, dprop)

See the full runnable script for further details.

Haskell

mhKernel :: (StatefulGen g m) => (s -> Double) -> (s -> g -> m s) ->
  (s -> s -> Double) -> g -> (s, Double) -> m (s, Double)
mhKernel logPost rprop dprop g (x0, ll0) = do
  x <- rprop x0 g
  let ll = logPost(x)
  let a = ll - ll0 + (dprop x0 x) - (dprop x x0)
  u <- (genContVar (uniformDistr 0.0 1.0)) g
  let next = if ((log u) < a)
        then (x, ll)
        else (x0, ll0)
  return next

malaKernel :: (StatefulGen g m) =>
  (Vector Double -> Double) -> (Vector Double -> Vector Double) -> 
  Vector Double -> Double -> g ->
  (Vector Double, Double) -> m (Vector Double, Double)
malaKernel lpi glpi pre dt g = let
  sdt = sqrt dt
  spre = cmap sqrt pre
  p = size pre
  advance beta = beta + (scalar (0.5*dt))*pre*(glpi beta)
  rprop beta g = do
    zl <- (replicateM p . genContVar (normalDistr 0.0 1.0)) g
    let z = fromList zl
    return $ advance(beta) + (scalar sdt)*spre*z
  dprop n o = let
    ao = advance o
    in sum $ (\i -> logDensity (normalDistr (ao!i) 
      ((spre!i)*sdt)) (n!i)) <$> [0..(p-1)]
  in mhKernel lpi rprop dprop g

See the full runnable script for further details.

Dex

Recall that Dex is differentiable, so we don’t need to provide gradients.

def mhKernel {s} (lpost: s -> Float) (rprop: s -> Key -> s) (dprop: s -> s -> Float)
    (sll: (s & Float)) (k: Key) : (s & Float) =
  (x0, ll0) = sll
  [k1, k2] = split_key k
  x = rprop x0 k1
  ll = lpost x
  a = ll - ll0 + (dprop x0 x) - (dprop x x0)
  u = rand k2
  select (log u < a) (x, ll) (x0, ll0)

def malaKernel {n} (lpi: (Fin n)=>Float -> Float)
    (pre: (Fin n)=>Float) (dt: Float) :
    ((Fin n)=>Float & Float) -> Key -> ((Fin n)=>Float & Float) =
  sdt = sqrt dt
  spre = sqrt pre
  glp = grad lpi
  v = dt .* pre
  vinv = map (\ x. 1.0/x) v
  def advance (beta: (Fin n)=>Float) : (Fin n)=>Float =
    beta + (0.5*dt) .* (pre*(glp beta))
  def rprop (beta: (Fin n)=>Float) (k: Key) : (Fin n)=>Float =
    (advance beta) + sdt .* (spre*(randn_vec k))
  def dprop (new: (Fin n)=>Float) (old: (Fin n)=>Float) : Float =
    ao = advance old
    diff = new - ao
    -0.5 * sum ((log v) + diff*diff*vinv)
  mhKernel lpi rprop dprop

See the full runnable script for further details.

Next steps

MALA gives us an MCMC algorithm that exploits gradient information to generate "informed" M-H proposals. But it still has a rather "diffusive" character, making it difficult to tune in such a way that large moves are likely to be accepted in challenging high-dimensional situations.

The Langevin dynamics on which MALA is based can be interpreted as the (over-damped) stochastic dynamics of a particle moving in a potential energy field corresponding to minus the log posterior. It turns out that the corresponding deterministic dynamics can be exploited to generate proposals better able to make large moves while still having a high probability of acceptance. This is the idea behind Hamiltonian Monte Carlo (HMC), which we’ll look at next.

Bayesian inference for a logistic regression model (Part 4)

Part 4: Gradients and the Langevin algorithm

Introduction

This is the fourth part in a series of posts on MCMC-based Bayesian inference for a logistic regression model. If you are new to this series, please go back to Part 1.

In the previous post we saw how the Metropolis algorithm could be used to generate a Markov chain targeting our posterior distribution. In high dimensions the diffusive nature of the Metropolis random walk proposal becomes increasingly inefficient. It is therefore natural to try and develop algorithms that use additional information about the target distribution. In the case of a differentiable log posterior target, a natural first step in this direction is to try and make use of gradient information.

Gradient of a logistic regression model

There are various ways to derive the gradient of our logistic regression model, but it might be simplest to start from the first form of the log likelihood that we deduced in Part 2:

\displaystyle l(\beta;y) = y^\textsf{T}X\beta - \mathbf{1}^\textsf{T}\log(\mathbf{1}+\exp[X\beta])

We can write this out in component form as

\displaystyle l(\beta;y) = \sum_i\sum_j y_iX_{ij}\beta_j - \sum_i\log\left(1+\exp\left[\sum_jX_{ij}\beta_j\right]\right).

Differentiating wrt \beta_k gives

\displaystyle \frac{\partial l}{\partial \beta_k} = \sum_i y_iX_{ik} - \sum_i \frac{\exp\left[\sum_j X_{ij}\beta_j\right]X_{ik}}{1+\exp\left[\sum_j X_{ij}\beta_j\right]}.

It’s then reasonably clear that stitching all of the partial derivatives together will give the gradient vector

\displaystyle \nabla l = X^\textsf{T}\left[ y - \frac{\mathbf{1}}{\mathbf{1}+\exp[-X\beta]} \right].

This is the gradient of the log likelihood, but we also need the gradient of the log prior. Since we are assuming independent \beta_i \sim N(0,v_i) priors, it is easy to see that the gradient of the log prior is just -\beta\circ v^{-1}. It is the sum of these two terms that gives the gradient of the log posterior.

R

In R we can implement our gradient function as

glp = function(beta) {
    glpr = -beta/(pscale*pscale)
    gll = as.vector(t(X) %*% (y - 1/(1 + exp(-X %*% beta))))
    glpr + gll
}

Python

In Python we could use

def glp(beta):
    glpr = -beta/(pscale*pscale)
    gll = (X.T).dot(y - 1/(1 + np.exp(-X.dot(beta))))
    return (glpr + gll)

We don’t really need a JAX version, since JAX can auto-diff the log posterior for us.
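Whichever way the gradient is obtained, it is worth checking a hand-coded version against a numerical approximation. Here is a minimal sketch of such a check in pure Python, using central finite differences. The tiny data set, the prior scales, the test point and the helper name `fd_grad` are all made up for illustration, and the log posterior is only defined up to an additive constant.

```python
import math

# Tiny made-up data set: 4 observations, intercept plus one covariate
X = [[1.0, 0.5], [1.0, -1.2], [1.0, 2.0], [1.0, 0.1]]
y = [1.0, 0.0, 1.0, 0.0]
pscale = [10.0, 1.0]                       # prior standard deviations

def linpred(beta):
    return [sum(xij*bj for xij, bj in zip(row, beta)) for row in X]

def lpost(beta):
    lprior = sum(-0.5*(b/s)**2 for b, s in zip(beta, pscale))
    llik = sum(yi*e - math.log(1.0 + math.exp(e))
               for yi, e in zip(y, linpred(beta)))
    return lprior + llik

def glp(beta):
    # Gradient of log prior plus X^T [y - 1/(1 + exp(-X beta))]
    resid = [yi - 1.0/(1.0 + math.exp(-e)) for yi, e in zip(y, linpred(beta))]
    gll = [sum(X[i][k]*resid[i] for i in range(len(X)))
           for k in range(len(beta))]
    return [-b/(s*s) + g for b, s, g in zip(beta, pscale, gll)]

def fd_grad(f, beta, h=1e-6):
    # Central finite-difference approximation to the gradient of f at beta
    g = []
    for k in range(len(beta)):
        up = list(beta)
        dn = list(beta)
        up[k] += h
        dn[k] -= h
        g.append((f(up) - f(dn))/(2*h))
    return g

beta = [0.3, -0.7]
max_err = max(abs(a - b) for a, b in zip(glp(beta), fd_grad(lpost, beta)))
```

If the analytic gradient is correct, `max_err` should be tiny compared with the size of the gradient components.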

Scala

  def glp(beta: DVD): DVD =
    val glpr = -beta /:/ pvar
    val gll = (X.t)*(y - ones/:/(ones + exp(-X*beta)))
    glpr + gll

Haskell

Using hmatrix we could use something like

glp :: Matrix Double -> Vector Double -> Vector Double -> Vector Double
glp x y b = let
  glpr = -b / (fromList [100.0, 1, 1, 1, 1, 1, 1, 1])
  gll = (tr x) #> (y - (scalar 1)/((scalar 1) + (cmap exp (-x #> b))))
  in glpr + gll

There’s something interesting to say about Haskell and auto-diff, but getting into this now will be too much of a distraction. I may come back to it in some future post.

Dex

Dex is differentiable, so we don’t need a gradient function – we can just use grad lpost. However, for interest and comparison purposes we could nevertheless implement it directly with something like

prscale = map (\ x. 1.0/x) pscale

def glp (b: (Fin 8)=>Float) : (Fin 8)=>Float =
  glpr = -b*prscale*prscale
  gll = (transpose x) **. (y - (map (\eta. 1.0/(1.0 + eta)) (exp (-x **. b))))
  glpr + gll

Langevin diffusions

Now that we have a way of computing the gradient of the log of our target density we need some MCMC algorithms that can make good use of it. In this post we will look at a simple approximate MCMC algorithm derived from an overdamped Langevin diffusion model. In subsequent posts we’ll look at more sophisticated, exact MCMC algorithms.

The multivariate stochastic differential equation (SDE)

\displaystyle dX_t = \frac{1}{2}\nabla\log\pi(X_t)dt + dW_t

has \pi(\cdot) as its equilibrium distribution. Informally, an SDE of this form is a continuous time process with infinitesimal transition kernel

\displaystyle X_{t+dt}|(X_t=x_t) \sim N\left(x_t+\frac{1}{2}\nabla\log\pi(x_t)dt,\mathbf{I}dt\right).

There are various more-or-less formal ways to see that \pi(\cdot) is stationary. A good way is to check it satisfies the Fokker–Planck equation with zero LHS. A less formal approach would be to see that the infinitesimal transition kernel for the process satisfies detailed balance with \pi(\cdot).
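For concreteness, here is a sketch of the Fokker–Planck check. Writing \rho(x,t) for the density of X_t (notation introduced just for this aside), the Fokker–Planck equation for the SDE above is

\displaystyle \frac{\partial \rho}{\partial t} = -\nabla\cdot\left[\frac{1}{2}\rho\,\nabla\log\pi\right] + \frac{1}{2}\nabla^2\rho.

Substituting \rho=\pi and using \pi\nabla\log\pi = \nabla\pi, the RHS becomes

\displaystyle -\frac{1}{2}\nabla\cdot\nabla\pi + \frac{1}{2}\nabla^2\pi = 0,

so the time derivative on the LHS vanishes and \pi(\cdot) is stationary.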

Similar arguments show that for any fixed positive definite matrix A, the SDE

\displaystyle dX_t = \frac{1}{2}A\nabla\log\pi(X_t)dt + A^{1/2}dW_t

also has \pi(\cdot) as a stationary distribution. It is quite common to choose a diagonal matrix A to put the components of X_t on a common scale.

The unadjusted Langevin algorithm

Simulating exact sample paths from SDEs such as the overdamped Langevin diffusion model is typically difficult (though not necessarily impossible), so we instead want something simple and tractable as the basis of our MCMC algorithms. Here we will just simulate from the Euler–Maruyama approximation of the process by choosing a small but finite time step \Delta t and using the transition kernel

\displaystyle X_{t+\Delta t}|(X_t=x_t) \sim N\left(x_t+\frac{1}{2}A\nabla\log\pi(x_t)\Delta t, A\Delta t\right)

as the basis of our MCMC method. For sufficiently small \Delta t this should accurately approximate the Langevin dynamics, leading to an equilibrium distribution very close to \pi(\cdot). That said, we would like to choose \Delta t as large as we can get away with, since that will lead to a more rapidly mixing MCMC chain. Below are some implementations of this kernel for a diagonal pre-conditioning matrix.

Implementation

R

We can create a kernel for the unadjusted Langevin algorithm in R with the following function.

ulKernel = function(glpi, dt = 1e-4, pre = 1) {
    sdt = sqrt(dt)
    spre = sqrt(pre)
    advance = function(x) x + 0.5*pre*glpi(x)*dt
    function(x, ll) rnorm(length(x), advance(x), spre*sdt)
}

Here, we can pass in pre, which is expected to be a vector representing the diagonal of the pre-conditioning matrix, A. We can then use this kernel to generate an MCMC chain as we have seen previously. See the full runnable script for further details.

Python

def ulKernel(glpi, dt = 1e-4, pre = 1):
    sdt = np.sqrt(dt)
    spre = np.sqrt(pre)
    advance = lambda x: x + 0.5*pre*glpi(x)*dt
    def kernel(x):
        # the dimension is taken from the state itself, not from a global
        return advance(x) + np.random.randn(len(x))*spre*sdt
    return kernel

See the full runnable script for further details.
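To get a feel for how the step size affects the accuracy of the unadjusted algorithm, consider the toy case of a one-dimensional standard normal target with A = 1. One step of the kernel is then x' = (1 - \Delta t/2)x + \sqrt{\Delta t}\,z with z \sim N(0,1), so the variance of the chain can be tracked exactly, without any simulation. The following pure Python sketch does this; the function name and the chosen step size are assumptions for illustration.

```python
def ula_stationary_variance(dt, n_iter=10000):
    # For pi = N(0,1), one ULA step is x' = (1 - dt/2)*x + sqrt(dt)*z,
    # so the variance evolves as v' = (1 - dt/2)^2 * v + dt
    v = 1.0
    for _ in range(n_iter):
        v = (1.0 - 0.5*dt)**2 * v + dt
    return v

dt = 0.4
iterated = ula_stationary_variance(dt)
# Solving the fixed point v = (1 - dt/2)^2 * v + dt gives v = 1/(1 - dt/4)
closed_form = 1.0/(1.0 - dt/4.0)
```

The equilibrium variance of the unadjusted chain is 1/(1 - \Delta t/4) rather than 1, so the bias vanishes as \Delta t \to 0 but is clearly visible for larger step sizes.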

JAX

def ulKernel(lpi, dt = 1e-4, pre = 1):
    glpi = jit(grad(lpi))
    sdt = jnp.sqrt(dt)
    spre = jnp.sqrt(pre)
    advance = jit(lambda x: x + 0.5*pre*glpi(x)*dt)
    @jit
    def kernel(key, x):
        # the shape of the noise is taken from the state, not from a global
        return advance(x) + jax.random.normal(key, x.shape)*spre*sdt
    return kernel

Note how for JAX we can just pass in the log posterior, and the gradient function can be obtained by automatic differentiation. See the full runnable script for further details.

Scala

def ulKernel(glp: DVD => DVD, pre: DVD, dt: Double): DVD => DVD =
  val sdt = math.sqrt(dt)
  val spre = sqrt(pre)
  def advance(beta: DVD): DVD =
    beta + (0.5*dt)*(pre*:*glp(beta))
  beta => advance(beta) + sdt*spre.map(Gaussian(0,_).sample())

See the full runnable script for further details.

Haskell

ulKernel :: (StatefulGen g m) =>
  (Vector Double -> Vector Double) -> Vector Double -> Double -> g ->
  Vector Double -> m (Vector Double)
ulKernel glpi pre dt g beta = do
  let sdt = sqrt dt
  let spre = cmap sqrt pre
  let p = size pre
  let advance beta = beta + (scalar (0.5*dt))*pre*(glpi beta)
  zl <- (replicateM p . genContVar (normalDistr 0.0 1.0)) g
  let z = fromList zl
  return $ advance(beta) + (scalar sdt)*spre*z

See the full runnable script for further details.

Dex

In Dex we can write a function that accepts a gradient function

def ulKernel {n} (glpi: (Fin n)=>Float -> (Fin n)=>Float)
    (pre: (Fin n)=>Float) (dt: Float)
    (b: (Fin n)=>Float) (k: Key) : (Fin n)=>Float =
  sdt = sqrt dt
  spre = sqrt pre
  b + (((0.5)*dt) .* (pre*(glpi b))) +
    (sdt .* (spre*(randn_vec k)))

or we can write a function that accepts a log posterior, and uses auto-diff to construct the gradient

def ulKernel {n} (lpi: (Fin n)=>Float -> Float)
    (pre: (Fin n)=>Float) (dt: Float)
    (b: (Fin n)=>Float) (k: Key) : (Fin n)=>Float =
  glpi = grad lpi
  sdt = sqrt dt
  spre = sqrt pre
  b + ((0.5)*dt) .* (pre*(glpi b)) +
    sdt .* (spre*(randn_vec k))

and since Dex is statically typed, we can’t easily mix these functions up.

See the full runnable scripts, without and with auto-diff.

Next steps

In this post we have seen how to construct an MCMC algorithm that makes use of gradient information. But this algorithm is approximate. In the next post we’ll see how to correct for the approximation by using the Langevin updates as proposals within a Metropolis-Hastings algorithm.

Bayesian inference for a logistic regression model (Part 3)

Part 3: The Metropolis algorithm

Introduction

This is the third part in a series of posts on MCMC-based Bayesian inference for a logistic regression model. If you are new to this series, please go back to Part 1.

In the previous post we derived the log posterior for the model and implemented it in a variety of programming languages and libraries. In this post we will construct a Markov chain having the posterior as its equilibrium.

MCMC

Detailed balance

A homogeneous Markov chain with transition kernel p(\theta_{n+1}|\theta_n) is said to satisfy detailed balance for some target distribution \pi(\theta) if

\displaystyle \pi(\theta)p(\theta'|\theta) = \pi(\theta')p(\theta|\theta'), \quad \forall \theta, \theta'

Integrating both sides wrt \theta gives

\displaystyle \int_\Theta \pi(\theta)p(\theta'|\theta)d\theta = \pi(\theta'),

from which it is clear that \pi(\cdot) is a stationary distribution of the chain (and the chain is reversible). Under fairly mild regularity conditions we expect \pi(\cdot) to be the equilibrium distribution of the chain.

For a given target \pi(\cdot) we would like to find an easy-to-sample-from transition kernel p(\cdot|\cdot) that satisfies detailed balance. This will then give us a way to (asymptotically) generate samples from our target.

In the context of Bayesian inference, the target \pi(\theta) will typically be the posterior distribution, which in the previous post we wrote as \pi(\theta|y). Here we drop the notational dependence on y, since MCMC can be used for any target distribution of interest.

Metropolis-Hastings

Suppose we have a fairly arbitrary easy-to-sample-from transition kernel q(\theta_{n+1}|\theta_n) and a target of interest, \pi(\cdot). Metropolis-Hastings (M-H) is a strategy for using q(\cdot|\cdot) to construct a new transition kernel p(\cdot|\cdot) satisfying detailed balance for \pi(\cdot).

The kernel p(\theta'|\theta) can be described algorithmically as follows:

  1. Call the current state of the chain \theta. Generate a proposal \theta^\star by simulating from q(\theta^\star|\theta).
  2. Compute the acceptance probability

\displaystyle \alpha(\theta^\star|\theta) = \min\left[1,\frac{\pi(\theta^\star)q(\theta|\theta^\star)}{\pi(\theta)q(\theta^\star|\theta)}\right].

  3. With probability \alpha(\theta^\star|\theta) return new state \theta'=\theta^\star, otherwise return \theta'=\theta.

It is clear from the algorithmic description that this kernel will have a point mass at \theta'=\theta, but that for \theta'\not=\theta the transition kernel will be p(\theta'|\theta)=q(\theta'|\theta)\alpha(\theta'|\theta). But then

\displaystyle \pi(\theta)p(\theta'|\theta) = \min[\pi(\theta)q(\theta'|\theta),\pi(\theta')q(\theta|\theta')]

is symmetric in \theta and \theta', and so detailed balance is satisfied. Since detailed balance is trivial at the point mass at \theta=\theta' we are done.
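The argument above can be checked numerically on a small discrete state space. The following is just a sketch: the three-state target pi and the asymmetric proposal matrix q are made-up numbers chosen for illustration. It builds the M-H transition matrix P, and then checks both detailed balance and stationarity.

```python
import numpy as np

pi = np.array([0.2, 0.5, 0.3])            # hypothetical target on three states
q = np.array([[0.10, 0.60, 0.30],         # q[i, j]: probability of proposing j from i
              [0.40, 0.20, 0.40],
              [0.50, 0.25, 0.25]])

n = len(pi)
P = np.zeros((n, n))
for i in range(n):
    for j in range(n):
        if i != j:
            alpha = min(1.0, (pi[j]*q[j, i])/(pi[i]*q[i, j]))
            P[i, j] = q[i, j]*alpha       # propose j, then accept
    P[i, i] = 1.0 - P[i].sum()            # point mass: rejections and self-proposals

flow = pi[:, None]*P                      # flow[i, j] = pi[i]*P[i, j]
detailed_balance = np.allclose(flow, flow.T)
stationary = np.allclose(pi @ P, pi)
print(detailed_balance, stationary)
```

Summing the detailed balance condition over the current state is the discrete analogue of the integration step used earlier, which is why stationarity follows.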

Metropolis algorithm

It is often convenient to generate proposals perturbatively, using a distribution that is symmetric about the current state of the chain. But then q(\theta'|\theta)=q(\theta|\theta'), and so q(\cdot|\cdot) drops out of the acceptance probability. This is the Metropolis algorithm.

Some computational tricks

To generate an event with probability \alpha(\theta^\star|\theta), we can generate a u\sim U(0,1) and accept if u < \alpha(\theta^\star|\theta). This is convenient for several reasons. First, it means that we can ignore the "min", and just accept if

\displaystyle u < \frac{\pi(\theta^\star)q(\theta|\theta^\star)}{\pi(\theta)q(\theta^\star|\theta)}

since u\leq 1 regardless. Better still, we can take logs, and accept if

\displaystyle \log u < \log\pi(\theta^\star) - \log\pi(\theta) + \log q(\theta|\theta^\star) - \log q(\theta^\star|\theta),

so there is no need to evaluate any raw densities. Again, in the case of a symmetric proposal distribution, the q(\cdot|\cdot) terms can be dropped.
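A small sketch of why working on the log scale matters: the two log densities below are made-up values of a size that is quite typical for a log posterior with a moderate amount of data. Exponentiating them underflows to zero, so the raw ratio is 0/0, whereas the difference of logs is perfectly well behaved.

```python
import numpy as np

log_pi_prop = -1000.0                     # hypothetical log density at the proposal
log_pi_curr = -1001.0                     # hypothetical log density at the current state

with np.errstate(invalid="ignore"):
    raw_ratio = np.exp(log_pi_prop)/np.exp(log_pi_curr)   # 0.0/0.0 after underflow

log_ratio = log_pi_prop - log_pi_curr     # log of the ratio, computed safely

u = 0.5                                   # stand-in for a U(0,1) draw
accept = np.log(u) < log_ratio
print(raw_ratio, log_ratio, accept)
```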

Another trick worth noting is that in the case of the simple M-H algorithm described, using a single update for the entire state space (and not multiple component-wise updates, for example), and assuming that the same M-H kernel is used repeatedly to generate successive states of a Markov chain, then the \log \pi(\theta) term (which in the context of Bayesian inference will typically be the log posterior) will have been computed at the previous update (irrespective of whether or not the previous move was accepted). So if we are careful about how we pass on that old value to the next iteration, we can avoid recomputing the log posterior, and our algorithm will only require one log posterior evaluation per iteration rather than two. In functional programming languages it is often convenient to pass around this current log posterior density evaluation explicitly, effectively augmenting the state space of the Markov chain to include the log posterior density.
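The bookkeeping described above can be sketched as follows. This is an illustrative toy rather than the code used in the rest of the post: the standard normal target, the counted wrapper and the proposal scale are all assumptions. The point is simply that, by carrying the current log density alongside the state, the target is evaluated exactly once per iteration.

```python
import numpy as np

def counted(f):
    # Wrap f so that we can count how many times it is called
    count = [0]
    def wrapper(x):
        count[0] += 1
        return f(x)
    return wrapper, count

def mh_kernel(lpost, rprop, rng):
    # Symmetric proposal assumed, so no proposal density terms are needed
    def kernel(x, ll):
        prop = rprop(x)
        lp = lpost(prop)                  # the only evaluation in this iteration
        if np.log(rng.uniform()) < lp - ll:
            return prop, lp               # accepted: pass on the new value
        return x, ll                      # rejected: pass on the old value
    return kernel

rng = np.random.default_rng(1)
lpost, count = counted(lambda x: -0.5*np.sum(x*x))
kernel = mh_kernel(lpost, lambda x: x + 0.5*rng.standard_normal(x.shape), rng)

x, ll = np.zeros(2), -np.inf              # ll = -inf forces acceptance of the first proposal
iters = 1000
for _ in range(iters):
    x, ll = kernel(x, ll)
print(count[0], iters)
```

Starting with ll = -Inf, as the runner functions in this post do, avoids a separate initial evaluation, at the price of always accepting the first proposal.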

HoF for a M-H kernel

Since I’m a fan of functional programming, we will adopt a functional style throughout, and start by creating a higher-order function (HoF) that accepts a log-posterior and proposal kernel as input and returns a Metropolis kernel as output.

R

In R we can write a function to create a M-H kernel as follows.

mhKernel = function(logPost, rprop, dprop = function(new, old, ...) { 1 })
    function(x, ll) {
        prop = rprop(x)
        llprop = logPost(prop)
        a = llprop - ll + dprop(x, prop) - dprop(prop, x)
        if (log(runif(1)) < a)
            list(x=prop, ll=llprop)
        else
            list(x=x, ll=ll)
    }

Note that the kernel returned requires as input both a current state x and its associated log-posterior, ll. The new state and log-posterior densities are returned.

We need to use this transition kernel to simulate a Markov chain by successive substitution of newly simulated values back into the kernel. In more sophisticated programming languages we will use streams for this, but in R we can just use a for loop to sample values and write the states into the rows of a matrix.

mcmc = function(init, kernel, iters = 10000, thin = 10, verb = TRUE) {
    p = length(init)
    ll = -Inf
    mat = matrix(0, nrow = iters, ncol = p)
    colnames(mat) = names(init)
    x = init
    if (verb) 
        message(paste(iters, "iterations"))
    for (i in 1:iters) {
        if (verb) 
            message(paste(i, ""), appendLF = FALSE)
        for (j in 1:thin) {
            pair = kernel(x, ll)
            x = pair$x
            ll = pair$ll
            }
        mat[i, ] = x
        }
    if (verb) 
        message("Done.")
    mat
}

Then, in the context of our running logistic regression example, and using the log-posterior from the previous post, we can construct our kernel and run it as follows.

pre = c(10.0,1,1,1,1,1,5,1)
out = mcmc(init, mhKernel(lpost,
          function(x) x + pre*rnorm(p, 0, 0.02)), thin=1000)

Note the use of a symmetric proposal, so the proposal density is not required. Also note the use of a larger proposal variance for the intercept term and the second last covariate. See the full runnable script for further details.

Python

We can do something very similar to R in Python using NumPy. Our HoF for constructing a M-H kernel is

def mhKernel(lpost, rprop, dprop = lambda new, old: 1.):
    def kernel(x, ll):
        prop = rprop(x)
        lp = lpost(prop)
        a = lp - ll + dprop(x, prop) - dprop(prop, x)
        if (np.log(np.random.rand()) < a):
            x = prop
            ll = lp
        return x, ll
    return kernel

Our Markov chain runner function is

def mcmc(init, kernel, thin = 10, iters = 10000, verb = True):
    p = len(init)
    ll = -np.inf
    mat = np.zeros((iters, p))
    x = init
    if (verb):
        print(str(iters) + " iterations")
    for i in range(iters):
        if (verb):
            print(str(i), end=" ", flush=True)
        for j in range(thin):
            x, ll = kernel(x, ll)
        mat[i,:] = x
    if (verb):
        print("\nDone.", flush=True)
    return mat

We can use this code in the context of our logistic regression example as follows.

pre = np.array([10.,1.,1.,1.,1.,1.,5.,1.])

def rprop(beta):
    return beta + 0.02*pre*np.random.randn(p)

out = mcmc(init, mhKernel(lpost, rprop), thin=1000)

See the full runnable script for further details.

JAX

The above R and Python scripts are fine, but both languages are rather slow for this kind of workload. Fortunately it’s rather straightforward to convert the Python code to JAX to obtain a quite amazing speed-up. We can write our M-H kernel as

def mhKernel(lpost, rprop, dprop = jit(lambda new, old: 1.)):
    @jit
    def kernel(key, x, ll):
        key0, key1 = jax.random.split(key)
        prop = rprop(key0, x)
        lp = lpost(prop)
        a = lp - ll + dprop(x, prop) - dprop(prop, x)
        accept = (jnp.log(jax.random.uniform(key1)) < a)
        return jnp.where(accept, prop, x), jnp.where(accept, lp, ll)
    return kernel

and our MCMC runner function as

def mcmc(init, kernel, thin = 10, iters = 10000):
    key = jax.random.PRNGKey(42)
    keys = jax.random.split(key, iters)
    @jit
    def step(s, k):
        [x, ll] = s
        x, ll = kernel(k, x, ll)
        s = [x, ll]
        return s, s
    @jit
    def iter(s, k):
        keys = jax.random.split(k, thin)
        _, states = jax.lax.scan(step, s, keys)
        final = [states[0][thin-1], states[1][thin-1]]
        return final, final
    ll = -np.inf
    x = init
    _, states = jax.lax.scan(iter, [x, ll], keys)
    return states[0]

There are really only two slightly tricky things about this code.

The first relates to the way JAX handles pseudo-random numbers. Since JAX is a pure functional eDSL, it can’t be used in conjunction with the typical pseudo-random number generators often used in imperative programming languages which rely on a global mutable state. This can be dealt with reasonably straightforwardly by explicitly passing around the random number state. There is a standard way of doing this that has been common practice in functional programming languages for decades. However, this standard approach is very sequential, and so doesn’t work so well in a parallel context. JAX therefore uses a splittable random number generator, where new states are created by splitting the current state into two (or more). We’ll come back to this when we get to the Haskell examples.
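For readers without JAX to hand, NumPy offers a loosely analogous mechanism via SeedSequence.spawn, which derives independent child streams from a parent seed. The sketch below uses NumPy rather than JAX (the helper name draws is my own), and it is only meant to convey the idea: child streams are fully determined by the parent, yet differ from one another.

```python
import numpy as np

def draws(seed):
    root = np.random.SeedSequence(seed)
    child_a, child_b = root.spawn(2)      # "split" the parent into two children
    a = np.random.default_rng(child_a).standard_normal(3)
    b = np.random.default_rng(child_b).standard_normal(3)
    return a, b

a1, b1 = draws(42)
a2, b2 = draws(42)                        # same parent seed, so same children
print(a1, b1)
```

Unlike a JAX key, a SeedSequence object is updated internally when it spawns children, so this is an analogy in spirit rather than a pure functional equivalent.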

The second thing that might be unfamiliar to imperative programmers is the use of the scan operation (jax.lax.scan) to generate the Markov chain rather than a "for" loop. But scans are standard operations in most functional programming languages.
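For anyone who hasn't met scans before, the semantics can be sketched in a few lines of plain Python. This is a toy model of the pattern and not how JAX implements it: the step function takes the current carry and one input, and returns the new carry together with an output to record.

```python
def scan(f, init, xs):
    # Thread a carry through xs, collecting one output per step
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys

def step(total, x):
    total = total + x
    return total, total                   # new carry, and the value to record

final, history = scan(step, 0, [1, 2, 3, 4])
print(final, history)                     # 10 [1, 3, 6, 10]
```

In the MCMC runner above, the carry is the pair of state and log posterior, the inputs are the random keys, and the recorded outputs form the chain. One difference worth noting is that jax.lax.scan stacks its outputs into arrays rather than returning a list.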

We can then call this code for our logistic regression example with

pre = jnp.array([10.,1.,1.,1.,1.,1.,5.,1.]).astype(jnp.float32)

@jit
def rprop(key, beta):
    return beta + 0.02*pre*jax.random.normal(key, [p])

out = mcmc(init, mhKernel(lpost, rprop), thin=1000)

See the full runnable script for further details.

Scala

In Scala we can use a similar approach to that already seen for defining a HoF to return a M-H kernel.

def mhKernel[S](
    logPost: S => Double, rprop: S => S,
    dprop: (S, S) => Double = (n: S, o: S) => 1.0
  ): ((S, Double)) => (S, Double) =
    val r = Uniform(0.0,1.0)
    state =>
      val (x0, ll0) = state
      val x = rprop(x0)
      val ll = logPost(x)
      val a = ll - ll0 + dprop(x0, x) - dprop(x, x0)
      if (math.log(r.draw()) < a)
        (x, ll)
      else
        (x0, ll0)

Note that Scala’s static typing does not prevent us from defining a function that is polymorphic in the type of the chain state, which we here call S. Also note that we are adopting a pragmatic approach to random number generation, exploiting the fact that Scala is not a pure functional language, using a mutable generator, and omitting to capture the non-determinism of the rprop function (and the returned kernel) in its type signature. In Scala this is a choice, and we could adopt a purer approach if preferred. We’ll see what such an approach will look like in Haskell, coming up next.

Now that we have the kernel, we don’t need to write an explicit runner function since Scala has good support for streaming data. There are many more-or-less sophisticated ways that we can work with data streams in Scala, and the choice depends partly on how pure one is being about tracking effects (such as non-determinism), but here I’ll just use the simple LazyList from the standard library for unfolding the kernel into an infinite MCMC chain before thinning and truncating appropriately.

  val pre = DenseVector(10.0,1.0,1.0,1.0,1.0,1.0,5.0,1.0)
  def rprop(beta: DVD): DVD = beta + pre *:* (DenseVector(Gaussian(0.0,0.02).sample(p).toArray))
  val kern = mhKernel(lpost, rprop)
  val s = LazyList.iterate((init, -Inf))(kern) map (_._1)
  val out = s.drop(150).thin(1000).take(10000)

See the full runnable script for further details.

Haskell

Since Haskell is a pure functional language, we need to have some convention regarding pseudo-random number generation. Haskell supports several styles. The most commonly adopted approach wraps a mutable generator up in a monad. The typical alternative is to use a pure functional generator and either explicitly thread the state through code or hide this in a monad similar to the standard approach. However, Haskell also supports the use of splittable generators, so we can consider all three approaches for comparative purposes. The approach taken does affect the code and the type signatures, and even the streaming data abstractions most appropriate for chain generation.

Starting with a HoF for producing a Metropolis kernel, an approach using the standard monadic generators could look like

mKernel :: (StatefulGen g m) => (s -> Double) -> (s -> g -> m s) -> 
           g -> (s, Double) -> m (s, Double)
mKernel logPost rprop g (x0, ll0) = do
  x <- rprop x0 g
  let ll = logPost(x)
  let a = ll - ll0
  u <- (genContVar (uniformDistr 0.0 1.0)) g
  let next = if ((log u) < a)
        then (x, ll)
        else (x0, ll0)
  return next

Note how non-determinism is captured in the type signatures by the monad m. The explicit pure approach is to thread the generator through non-deterministic functions.

mKernelP :: (RandomGen g) => (s -> Double) -> (s -> g -> (s, g)) -> 
            g -> (s, Double) -> ((s, Double), g)
mKernelP logPost rprop g (x0, ll0) = let
  (x, g1) = rprop x0 g
  ll = logPost(x)
  a = ll - ll0
  (u, g2) = uniformR (0, 1) g1
  next = if ((log u) < a)
        then (x, ll)
        else (x0, ll0)
  in (next, g2)

Here the updated random number generator state is returned from each non-deterministic function for passing on to subsequent non-deterministic functions. This explicit sequencing of operations makes it possible to wrap the generator state in a state monad giving code very similar to the stateful monadic generator approach, but as already discussed, the sequential nature of this approach makes it unattractive in parallel and concurrent settings.

Fortunately the standard Haskell pure generator is splittable, meaning that we can adopt a splitting approach similar to JAX if we prefer, since this is much more parallel-friendly.

mKernelP :: (RandomGen g) => (s -> Double) -> (s -> g -> s) -> 
            g -> (s, Double) -> (s, Double)
mKernelP logPost rprop g (x0, ll0) = let
  (g1, g2) = split g
  x = rprop x0 g1
  ll = logPost(x)
  a = ll - ll0
  u = unif g2
  next = if ((log u) < a)
        then (x, ll)
        else (x0, ll0)
  in next

Here non-determinism is signalled by passing a generator state (often called a "key" in the context of splittable generators) into a function. Functions receiving a key are responsible for splitting it to ensure that no key is ever used more than once.

Once we have a kernel, we need to unfold our Markov chain. When using the monadic generator approach, it is most natural to unfold using a monadic stream

mcmc :: (StatefulGen g m) =>
  Int -> Int -> s -> (g -> s -> m s) -> g -> MS.Stream m s
mcmc it th x0 kern g = MS.iterateNM it (stepN th (kern g)) x0

stepN :: (Monad m) => Int -> (a -> m a) -> (a -> m a)
stepN n fa = if (n == 1)
  then fa
  else (\x -> (fa x) >>= (stepN (n-1) fa))

whereas for the explicit approaches it is more natural to unfold into a regular infinite data stream. So, for the explicit sequential approach we could use

mcmcP :: (RandomGen g) => s -> (g -> s -> (s, g)) -> g -> DS.Stream s
mcmcP x0 kern g = DS.unfold stepUf (x0, g)
  where
    stepUf xg = let
      (x1, g1) = kern (snd xg) (fst xg)
      in (x1, (x1, g1))

and with the splittable approach we could use

mcmcP :: (RandomGen g) =>
  s -> (g -> s -> s) -> g -> DS.Stream s
mcmcP x0 kern g = DS.unfold stepUf (x0, g)
  where
    stepUf xg = let
      (x1, g1) = xg
      (g2, g3) = split g1
      x2 = kern g3 x1
      in (x2, (x2, g2))

Calling these functions for our logistic regression example is similar to what we have seen before, but again there are minor syntactic differences depending on the approach. For further details see the full runnable scripts for the monadic approach, the pure sequential approach, and the splittable approach.

Dex

Dex is a pure functional language and uses a splittable random number generator, so the style we use is similar to JAX (or Haskell using a splittable generator). We can generate a Metropolis kernel with

def mKernel {s} (lpost: s -> Float) (rprop: Key -> s -> s) : 
    Key -> (s & Float) -> (s & Float) =
  def kern (k: Key) (sll: (s & Float)) : (s & Float) =
    (x0, ll0) = sll
    [k1, k2] = split_key k
    x = rprop k1 x0
    ll = lpost x
    a = ll - ll0
    u = rand k2
    select (log u < a) (x, ll) (x0, ll0)
  kern

We can then unfold our Markov chain with

def markov_chain {s} (k: Key) (init: s) (kern: Key -> s -> s) (its: Nat) :
    Fin its => s =
  with_state init \st.
    for i:(Fin its).
      x = kern (ixkey k i) (get st)
      st := x
      x

Here we combine Dex’s state effect with a for loop to unfold the stream. See the full runnable script for further details.

Next steps

As previously discussed, none of these codes are optimised, so care should be taken not to over-interpret running times. However, JAX and Dex are noticeably faster than the alternatives, even running on a single CPU core. Another interesting feature of both JAX and Dex is that they are differentiable. This makes it very easy to develop algorithms using gradient information. In subsequent posts we will think about the gradient of our example log-posterior and how we can use gradient information to develop "better" sampling algorithms.

The complete runnable scripts are all available from this public github repo.

Bayesian inference for a logistic regression model (Part 2)

Part 2: The log posterior

Introduction

This is the second part in a series of posts on MCMC-based Bayesian inference for a logistic regression model. If you are new to this series, please go back to Part 1.

In the previous post we looked at the basic modelling concepts, and how to fit the model using a variety of PPLs. In this post we will prepare for doing MCMC by considering the problem of computing the unnormalised log posterior for the model. We will then see how this posterior can be implemented in several different languages and libraries.

Derivation

Basic structure

In Bayesian inference the posterior distribution is just the conditional distribution of the model parameters given the data, and therefore proportional to the joint distribution of the model and data. We often write this as

\displaystyle \pi(\theta|y) \propto \pi(\theta,y) = \pi(\theta)\pi(y|\theta).

Taking logs we have

\displaystyle \log \pi(\theta, y) = \log \pi(\theta) + \log \pi(y|\theta).

So (up to an additive constant) the log posterior is just the sum of the log prior and log likelihood. There are many good (numerical) reasons why we try to work exclusively with the log posterior and try to avoid ever evaluating the raw posterior density.

For our example logistic regression model, the parameter vector \theta is just the vector of regression coefficients, \beta. We assumed independent mean zero normal priors for the components of this vector, so the log prior is just the sum of logs of normal densities. Many scientific libraries will have built-in functions for returning the log-pdf of standard distributions, but if an explicit form is required, the log of the density of a N(0,\sigma^2) at x is just

\displaystyle -\log(2\pi)/2 - \log|\sigma| - x^2/(2\sigma^2),

and the initial constant term normalising the density can often be dropped.
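As a quick sanity check of the explicit form, the sketch below compares it with the log of the density from Python's standard library. The function name log_normal_density and the test values are arbitrary choices of mine. It also illustrates one of the numerical reasons for staying on the log scale: far out in the tail the raw density underflows to zero, but the log density is still perfectly finite.

```python
import math
from statistics import NormalDist

def log_normal_density(x, sigma):
    # Log density of N(0, sigma^2) at x, using the explicit form above
    return -0.5*math.log(2*math.pi) - math.log(abs(sigma)) - x*x/(2*sigma*sigma)

x, sigma = 1.5, 10.0
explicit = log_normal_density(x, sigma)
reference = math.log(NormalDist(0.0, sigma).pdf(x))
print(explicit, reference)

tail_density = NormalDist(0.0, sigma).pdf(400.0)     # underflows to 0.0
tail_log_density = log_normal_density(400.0, sigma)  # still finite
print(tail_density, tail_log_density)
```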

Log-likelihood (first attempt)

Information from the data comes into the log posterior via the log-likelihood. The typical way to derive the likelihood for problems of this type is to assume the usual binary encoding of the data (success 1, failure 0). Then, for a Bernoulli observation with probability p_i,\ i=1,\ldots,n, the likelihood associated with observation y_i is

\displaystyle f(y_i|p_i) = \left[ \hphantom{1-}p_i \quad :\ y_i=1 \atop 1-p_i \quad :\ y_i=0 \right. \quad = \quad p_i^{y_i}(1-p_i)^{1-y_i}.

Taking logs and then switching to parameter \eta_i=\text{logit}(p_i) we have

\displaystyle \log f(y_i|\eta_i) = y_i\eta_i - \log(1+e^{\eta_i}),

and summing over n observations gives the log likelihood

\displaystyle \log\pi(y|\eta) \equiv \ell(\eta;y) = y\cdot \eta - \mathbf{1}\cdot\log(\mathbf{1}+\exp{\eta}).

In the context of logistic regression, \eta is the linear predictor, so \eta=X\beta, giving

\displaystyle \ell(\beta;y) = y^\textsf{T}X\beta - \mathbf{1}^\textsf{T}\log(\mathbf{1}+\exp[X\beta]).

This is a perfectly good way of expressing the log-likelihood, and we will come back to it later when we want the gradient of the log-likelihood, but it turns out that there is a similar-but-different way of deriving it that results in an expression that is equivalent but slightly cheaper to evaluate.

Log-likelihood (second attempt)

For our second attempt, we will assume that the data is coded in a different way. Instead of the usual binary encoding, we will assume that the observation \tilde y_i is 1 for success and -1 for failure. This isn’t really a problem, since the two encodings are related by \tilde y_i = 2y_i-1. This new encoding is convenient in the context of a logit parameterisation since then

\displaystyle f(y_i|\eta_i) = \left[ p_i \ :\ \tilde y_i=1\atop 1-p_i\ :\ \tilde y_i=-1 \right. \ = \ \left[ (1+e^{-\eta_i})^{-1} \ :\ \tilde y_i=1\atop (1+e^{\eta_i})^{-1} \ :\ \tilde y_i=-1 \right. \ = \ (1+e^{-\tilde y_i\eta_i})^{-1} ,

and hence

\displaystyle \log f(y_i|\eta_i) = -\log(1+e^{-\tilde y_i\eta_i}).

Summing over observations gives

\displaystyle \ell(\eta;\tilde y) = -\mathbf{1}\cdot \log(\mathbf{1}+\exp[-\tilde y\circ \eta]),

where \circ denotes the Hadamard product. Substituting \eta=X\beta gives the log-likelihood

\displaystyle \ell(\beta;\tilde y) = -\mathbf{1}^\textsf{T} \log(\mathbf{1}+\exp[-\tilde y\circ X\beta]).

This likelihood is a bit cheaper to evaluate than the one previously derived. If we prefer to write in terms of the original data encoding, we can obviously do so as

\displaystyle \ell(\beta; y) = -\mathbf{1}^\textsf{T} \log(\mathbf{1}+\exp[-(2y-\mathbf{1})\circ (X\beta)]),

and in practice, it is this version that is typically used. To be clear, as an algebraic function of \beta and y the two functions are different. But they coincide for binary vectors y which is all that matters.
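This last point is easy to check numerically. The sketch below uses a small simulated design matrix and coefficient vector (made up purely for illustration, not the data set used in this series) and compares the two expressions, first on a binary y and then on a non-binary y.

```python
import numpy as np

rng = np.random.default_rng(0)
n, p = 20, 3
X = np.column_stack([np.ones(n), rng.standard_normal((n, p - 1))])  # intercept first
beta = rng.standard_normal(p)
y = rng.integers(0, 2, size=n).astype(float)                        # binary responses

def ll_first(beta, y):
    # First attempt: y'X beta - 1' log(1 + exp(X beta))
    eta = X @ beta
    return y @ eta - np.sum(np.log(1 + np.exp(eta)))

def ll_second(beta, y):
    # Second attempt: -1' log(1 + exp(-(2y - 1) o X beta))
    eta = X @ beta
    return -np.sum(np.log(1 + np.exp(-(2*y - 1)*eta)))

agree_on_binary = np.isclose(ll_first(beta, y), ll_second(beta, y))

y_frac = np.full(n, 0.5)                                            # not a binary vector
differ_off_binary = not np.isclose(ll_first(beta, y_frac), ll_second(beta, y_frac))
print(agree_on_binary, differ_off_binary)
```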

Implementation

R

In R we can create functions for evaluating the log-likelihood, log-prior and log-posterior as follows (assuming that X and y are in scope).

ll = function(beta)
    sum(-log(1 + exp(-(2*y - 1)*(X %*% beta))))

lprior = function(beta)
    dnorm(beta[1], 0, 10, log=TRUE) + sum(dnorm(beta[-1], 0, 1, log=TRUE))

lpost = function(beta) ll(beta) + lprior(beta)

Python

In Python (with NumPy and SciPy) we can define equivalent functions with

def ll(beta):
    return np.sum(-np.log(1 + np.exp(-(2*y - 1)*(X.dot(beta)))))

def lprior(beta):
    return (sp.stats.norm.logpdf(beta[0], loc=0, scale=10) + 
            np.sum(sp.stats.norm.logpdf(beta[range(1,p)], loc=0, scale=1)))

def lpost(beta):
    return ll(beta) + lprior(beta)

JAX

Python, like R, is a dynamic language, and relatively slow for MCMC algorithms. JAX is a tensor computation framework for Python that embeds a pure functional differentiable array processing language inside Python. JAX can JIT-compile high-performance code for both CPU and GPU, and has good support for parallelism. It is rapidly becoming the preferred way to develop high-performance sampling algorithms within the Python ecosystem. We can encode our log-posterior in JAX as follows.

@jit
def ll(beta):
    return jnp.sum(-jnp.log(1 + jnp.exp(-(2*y - 1)*jnp.dot(X, beta))))

@jit
def lprior(beta):
    return (jsp.stats.norm.logpdf(beta[0], loc=0, scale=10) + 
            jnp.sum(jsp.stats.norm.logpdf(beta[jnp.array(range(1,p))], loc=0, scale=1)))

@jit
def lpost(beta):
    return ll(beta) + lprior(beta)

Scala

JAX is a pure functional programming language embedded in Python. Pure functional programming languages are intrinsically more scalable and compositional than imperative languages such as R and Python, and are much better suited to exploit concurrency and parallelism. I’ve given a bunch of talks about this recently, so if you are interested in this, perhaps start with the materials for my Laplace’s Demon talk. Scala and Haskell are arguably the current best popular general purpose functional programming languages, so it is possibly interesting to consider the use of these languages for the development of scalable statistical inference codes. Since both languages are statically typed compiled functional languages with powerful type systems, they can be highly performant. However, neither is optimised for numerical (tensor) computation, so you should not expect that they will have performance comparable with optimised tensor computation frameworks such as JAX. We can encode our log-posterior in Scala (with Breeze) as follows:

  def ll(beta: DVD): Double =
      sum(-log(ones + exp(-1.0*(2.0*y - ones)*:*(X * beta))))

  def lprior(beta: DVD): Double =
    Gaussian(0,10).logPdf(beta(0)) + 
      sum(beta(1 until p).map(Gaussian(0,1).logPdf(_)))

  def lpost(beta: DVD): Double = ll(beta) + lprior(beta)

Spark

Apache Spark is a Scala library for distributed "big data" processing on clusters of machines. Despite fundamental differences, there is a sense in which Spark for Scala is a bit analogous to JAX for Python: both Spark and JAX are concerned with scalability, but they are targeting rather different aspects of scalability: JAX is concerned with getting regular sized data processing algorithms to run very fast (on GPUs), whereas Spark is concerned with running huge data processing tasks quickly by distributing work over clusters of machines. Despite obvious differences, the fundamental pure functional computational model adopted by both systems is interestingly similar: both systems are based on lazy transformations of immutable data structures using pure functions. This is a fundamental pattern for scalable data processing transcending any particular language, library or framework. We can encode our log posterior in Spark as follows.

    def ll(beta: DVD): Double = 
      df.map{row =>
        val y = row.getAs[Double](0)
        val x = BDV.vertcat(BDV(1.0),toBDV(row.getAs[DenseVector](1)))
        -math.log(1.0 + math.exp(-1.0*(2.0*y-1.0)*(x.dot(beta))))}.reduce(_+_)
    def lprior(beta: DVD): Double =
      Gaussian(0,10).logPdf(beta(0)) +
        sum(beta(1 until p).map(Gaussian(0,1).logPdf(_)))
    def lpost(beta: DVD): Double =
      ll(beta) + lprior(beta)

Haskell

Haskell is an old, lazy pure functional programming language with an advanced type system, and remains the preferred language for the majority of functional programming language researchers. Hmatrix is the standard high performance numerical linear algebra library for Haskell, so we can use it to encode our log-posterior as follows.

ll :: Matrix Double -> Vector Double -> Vector Double -> Double
ll x y b = (negate) (vsum (cmap log (
                              (scalar 1) + (cmap exp (cmap (negate) (
                                                         (((scalar 2) * y) - (scalar 1)) * (x #> b)
                                                         )
                                                     )))))

pscale :: [Double] -- prior standard deviations
pscale = [10.0, 1, 1, 1, 1, 1, 1, 1]
lprior :: Vector Double -> Double
lprior b = sum $ (\x -> logDensity (normalDistr 0.0 (snd x)) (fst x)) <$> (zip (toList b) pscale)
           
lpost :: Matrix Double -> Vector Double -> Vector Double -> Double
lpost x y b = (ll x y b) + (lprior b)

Again, a reminder that, here and elsewhere, there are various optimisations that could be done that I’m not bothering with. This is all just proof-of-concept code.

Dex

JAX proves that a pure functional DSL for tensor computation can be extremely powerful and useful. But embedding such a language in a dynamic imperative language like Python has a number of drawbacks. Dex is an experimental statically typed stand-alone DSL for differentiable array and tensor programming that attempts to combine some of the correctness and composability benefits of powerful statically typed functional languages like Scala and Haskell with the performance benefits of tensor computation systems like JAX. It is currently rather early in its development, but seems very interesting, and is already quite usable. We can encode our log-posterior in Dex as follows.

def ll (b: (Fin 8)=>Float) : Float =
  neg $ sum (log (map (\ x. (exp x) + 1) ((map (\ yi. 1 - 2*yi) y)*(x **. b))))

pscale = [10.0, 1, 1, 1, 1, 1, 1, 1] -- prior SDs
prscale = map (\ x. 1.0/x) pscale

def lprior (b: (Fin 8)=>Float) : Float =
  bs = b*prscale
  neg $  sum ((log pscale) + (0.5 .* (bs*bs)))

def lpost (b: (Fin 8)=>Float) : Float =
  (ll b) + (lprior b)
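
Note that the Dex prior above drops the constant -0.5*log(2*pi) from each normal log-density, whereas a full normal log-density (as I assume logDensity returns in the Haskell version) keeps it. Since we only need the log-posterior up to an additive constant, this makes no difference to MCMC. The following minimal Python sketch illustrates the point; the function names and test vectors are hypothetical.

```python
import math

PSCALE = [10.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]  # prior SDs, as in the post

def lprior_full(b, pscale=PSCALE):
    """Sum of N(0, sd^2) log-densities, including the normalising constant."""
    return sum(
        -0.5 * math.log(2.0 * math.pi) - math.log(sd) - 0.5 * (bi / sd) ** 2
        for bi, sd in zip(b, pscale)
    )

def lprior_unnorm(b, pscale=PSCALE):
    """The same prior without the -0.5*log(2*pi) terms, as in the Dex code."""
    return -sum(math.log(sd) + 0.5 * (bi / sd) ** 2 for bi, sd in zip(b, pscale))

# Two arbitrary parameter vectors (made up for illustration)
b1 = [0.0] * 8
b2 = [1.0, -0.5, 0.2, 0.0, 0.3, -1.0, 0.7, 0.1]

# The gap between the two versions should not depend on b
offset = 8 * 0.5 * math.log(2.0 * math.pi)
print(lprior_unnorm(b1) - lprior_full(b1), lprior_unnorm(b2) - lprior_full(b2), offset)
```

Because the difference is the same constant for every b, it cancels in any Metropolis-Hastings acceptance ratio.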

Next steps

Now that we have a way of evaluating the log posterior, we can think about constructing Markov chains having the posterior as their equilibrium distribution. In the next post we will look at one of the simplest ways of doing this: the Metropolis algorithm.

Complete runnable scripts are available from this public github repo.

Bayesian inference for a logistic regression model (Part 1)

Part 1: The basics

Introduction

This is the first in a series of posts on MCMC-based fully Bayesian inference for a logistic regression model. In this series we will look at the model, and see how the posterior distribution can be sampled using a variety of different programming languages and libraries.

Logistic regression

Logistic regression is concerned with predicting a binary outcome based on some covariate information. The probability of "success" is modelled via a logistic transformation of a linear predictor constructed from the covariate vector.

This is a very simple model, but is a convenient toy example since it is arguably the simplest interesting example of an intractable (nonlinear) statistical model requiring some kind of iterative numerical fitting method, even in the non-Bayesian setting. In a Bayesian context, the posterior distribution is intractable, necessitating either approximate or computationally intensive numerical methods of "solution". In this series of posts, we will mainly concentrate on MCMC algorithms for sampling the full posterior distribution of the model parameters given some observed data.

We assume n observations and p covariates (including an intercept term that is always 1). The binary observations y_i,\ i=1,\ldots,n are 1 for a "success" and 0 for a "failure". The covariate p-vectors x_i, i=1,\ldots,n all have 1 as the first element. The statistical model is

\displaystyle \text{logit}\left(\text{Pr}[Y_i = 1]\right) = x_i \cdot \beta,\quad i=1,\ldots,n,

where \beta is a p-vector of parameters, a\cdot b = a^\textsf{T} b, and

\displaystyle \text{logit}(q) \equiv \log\left(\frac{q}{1-q}\right),\quad \forall\ q\in (0,1).

Equivalently,

\displaystyle \text{Pr}[Y_i = 1] = \text{expit}(x_i \cdot \beta),\quad i=1,\ldots,n,

where

\displaystyle \text{expit}(\theta) \equiv \frac{1}{1+e^{-\theta}},\quad \forall\ \theta\in\mathbb{R}.

Note that the expit function is sometimes called the logistic or sigmoid function, but expit is slightly less ambiguous. The statistical problem is to choose the parameter vector \beta to provide the "best" model for the probability of success. In the Bayesian setting, a prior distribution (typically multivariate normal) is specified for \beta, and then the posterior distribution after conditioning on the data is the object of inferential interest.
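
To make the two link functions concrete, here is a minimal Python sketch of logit and expit using only the standard library. The implementation choices (branching on the sign of the argument to avoid overflow) are my own and are just one reasonable way of doing it.

```python
import math

def logit(q):
    """log(q / (1 - q)) for q strictly between 0 and 1."""
    if not 0.0 < q < 1.0:
        raise ValueError("q must lie strictly between 0 and 1")
    return math.log(q) - math.log1p(-q)

def expit(theta):
    """1 / (1 + exp(-theta)), written so that exp() never overflows."""
    if theta >= 0.0:
        return 1.0 / (1.0 + math.exp(-theta))
    e = math.exp(theta)  # theta < 0, so e lies in [0, 1)
    return e / (1.0 + e)

# expit and logit are inverses of one another
theta = 1.3
print(expit(theta), logit(expit(theta)))
```

The symmetry expit(theta) + expit(-theta) = 1 reflects the fact that the model treats "success" and "failure" in the same way, up to the sign of the linear predictor.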

Example problem

In order to illustrate the ideas, it is useful to have a small running example. Here we will use the (infamous) Pima training dataset (MASS::Pima.tr in R), which has n=200 observations and 7 predictors. Adding an intercept gives p=8 covariates. For the Bayesian analysis, we need a prior on \beta. We will assume independent mean zero normal distributions for each component. The prior standard deviation for the intercept will be 10 and for the other covariates will be 1.

Describing the model in some PPLs

In this first post in the series, we will use probabilistic programming languages (PPLs) to describe the model and sample the posterior distribution.

JAGS

JAGS is a stand-alone domain specific language (DSL) for probabilistic programming. It can be used independently of general purpose programming languages, or called from popular languages for data science such as Python and R. We can describe our model in JAGS with the following code.

  model {
    for (i in 1:n) {
      y[i] ~ dbern(pr[i])
      logit(pr[i]) <- inprod(X[i,], beta)
    }
    beta[1] ~ dnorm(0, 0.01)
    for (j in 2:p) {
      beta[j] ~ dnorm(0, 1)
    }
  }

Note that JAGS uses precision as the second parameter of a normal distribution. See the full runnable R script for further details. Given this model description, JAGS can construct an MCMC sampler for the posterior distribution of the model parameters given the data. See the full script for how to feed in the data, run the sampler, and analyse the output.
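
Since the precision parameterisation is a common source of bugs when moving between PPLs, here is a tiny Python sketch of the conversion. The function names are hypothetical. Precision is the reciprocal of the variance, so a prior standard deviation of 10 corresponds to the 0.01 used in the JAGS model above.

```python
import math

def precision_from_sd(sd):
    """Precision (1 / variance) corresponding to a standard deviation."""
    return 1.0 / (sd * sd)

def sd_from_precision(tau):
    """Standard deviation corresponding to a precision."""
    return 1.0 / math.sqrt(tau)

# Intercept prior: SD 10; other coefficients: SD 1
print(precision_from_sd(10.0), precision_from_sd(1.0))
```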

Stan

Stan is another stand-alone DSL for probabilistic programming, and has a very sophisticated sampling algorithm, making it a popular choice for non-trivial models. It uses gradient information for sampling, and therefore requires a differentiable log-posterior. We could encode our logistic regression model as follows.

data {
  int<lower=1> n;
  int<lower=1> p;
  int<lower=0, upper=1> y[n];
  real X[n,p];
}
parameters {
  real beta[p];
}
model {
  for (i in 1:n) {
    real eta = dot_product(beta, X[i,]);
    real pr = 1/(1+exp(-eta));
    y[i] ~ binomial(1, pr);
  }
  beta[1] ~ normal(0, 10);
  for (j in 2:p) {
    beta[j] ~ normal(0, 1);
  }
}

Note that Stan uses standard deviation as the second parameter of the normal distribution. See the full runnable R script for further details.

PyMC

JAGS and Stan are stand-alone DSLs for probabilistic programming. This has the advantage of making them independent of any particular host (general purpose) programming language. But it also means that they are not able to take advantage of the language and tool support of an existing programming language. An alternative to a stand-alone DSL is an embedded DSL (eDSL). Here, a DSL is embedded as a library or package within an existing (general purpose) programming language. Then, ideally, in the context of PPLs, probabilistic programs can become ordinary values within the host language, and this can have some advantages, especially if the host language is sophisticated. A number of probabilistic programming languages have been implemented as eDSLs in Python. Python is not a particularly sophisticated language, so the advantages here are limited, but not negligible.

PyMC is probably the most popular eDSL PPL in Python. We can encode our model in PyMC as follows.

pscale = np.array([10.,1.,1.,1.,1.,1.,1.,1.])
with pm.Model() as model:
    beta = pm.Normal('beta', 0, pscale, shape=p)
    eta = pmm.matrix_dot(X, beta)
    pm.Bernoulli('y', logit_p=eta, observed=y)
    traces = pm.sample(2500, tune=1000, init="adapt_diag", return_inferencedata=True)

See the full runnable script for further details.

NumPyro

NumPyro is a fork of Pyro for NumPy and JAX (of which more later). We can encode our model with NumPyro as follows.

pscale = jnp.array([10.,1.,1.,1.,1.,1.,1.,1.]).astype(jnp.float32)
def model(X, y):
    coefs = numpyro.sample("beta", dist.Normal(jnp.zeros(p), pscale))
    logits = jnp.dot(X, coefs)
    return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=y)

See the full runnable script for further details.

Please note that none of the above examples have been optimised, or are even necessarily expressed idiomatically within each PPL. I’ve just tried to express the model in as simple and similar a way as possible across the PPLs. For example, I know that there is a function bernoulli_logit_glm in Stan which would simplify the model and improve sampling efficiency, but I’ve deliberately not used it in order to try and keep the implementations as basic as possible. The same will be true for all of the code examples in this series of blog posts. The code has not been optimised and should therefore not be used for serious benchmarking.

Next steps

PPLs are convenient, and are becoming increasingly sophisticated. Each of the above PPLs provides a simple way to pass in observed data, and automatically construct an MCMC algorithm for sampling from the implied posterior distribution – see the full scripts for details. All of the PPLs work well for this problem, and all produce identical results up to Monte Carlo error. Each PPL has its own approach to sampler construction, and some PPLs offer multiple choices. However, more challenging problems often require highly customised samplers. Such samplers will often need to be created from scratch, and will require (at least) the ability to compute the (unnormalised log) posterior density at proposed parameter values, so in the next post we will look at how this can be derived for this model (in a couple of different ways) and coded up from scratch in a variety of programming languages.

All of the complete, runnable code associated with this series of blog posts can be obtained from this public github repo.

scala-glm: Regression modelling in Scala

Introduction

As discussed in the previous post, I’ve recently constructed and delivered a short course on statistical computing with Scala. Much of the course is concerned with writing statistical algorithms in Scala, typically making use of the scientific and numerical computing library, Breeze. Breeze has all of the essential tools necessary for building statistical algorithms, but doesn’t contain any higher level modelling functionality. As part of the course, I walked through how to build a small library for regression modelling on top of Breeze, including all of the usual regression diagnostics (such as standard errors, t-statistics, p-values, F-statistics, etc.). While preparing the course materials it occurred to me that it would be useful to package and document this code properly for general use. In advance of the course I packaged the code up into a bare-bones library, but since then I’ve fleshed it out, tidied it up and documented it properly, so it’s now ready for people to use.

The library covers PCA, linear regression modelling and simple one-parameter GLMs (including logistic and Poisson regression). The underlying algorithms are fairly efficient and numerically stable (eg. linear regression uses the QR decomposition of the model matrix, and the GLM fitting uses QR within each IRLS step), though they are optimised more for clarity than speed. The library also includes a few utility functions and procedures, including a pairs plot (scatter-plot matrix).
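
To illustrate the idea of fitting a linear regression via the QR decomposition of the model matrix, here is a minimal pure-Python sketch. It is not the library's code: scala-glm relies on Breeze, whereas this sketch uses modified Gram-Schmidt on plain lists, and the function names and toy data are my own assumptions. The point is only that, with X = QR, the least squares estimate solves the triangular system R beta = Q^T y, so the normal equations never need to be formed.

```python
def qr_mgs(a):
    """Thin QR of an n x p matrix (list of rows) by modified Gram-Schmidt.

    Returns (q_cols, r): p orthonormal columns and a p x p upper
    triangular matrix.
    """
    n, p = len(a), len(a[0])
    v = [[a[i][j] for i in range(n)] for j in range(p)]  # working columns
    q_cols = []
    r = [[0.0] * p for _ in range(p)]
    for j in range(p):
        norm = sum(t * t for t in v[j]) ** 0.5
        if norm == 0.0:
            raise ValueError("model matrix is rank deficient")
        r[j][j] = norm
        qj = [t / norm for t in v[j]]
        q_cols.append(qj)
        # Remove the component along qj from all remaining columns
        for k in range(j + 1, p):
            r[j][k] = sum(qi * vi for qi, vi in zip(qj, v[k]))
            v[k] = [vi - r[j][k] * qi for vi, qi in zip(v[k], qj)]
    return q_cols, r

def lstsq_qr(a, y):
    """Least squares fit: solve R beta = Q^T y by back-substitution."""
    q_cols, r = qr_mgs(a)
    p = len(r)
    qty = [sum(qi * yi for qi, yi in zip(qj, y)) for qj in q_cols]
    beta = [0.0] * p
    for j in reversed(range(p)):
        s = qty[j] - sum(r[j][k] * beta[k] for k in range(j + 1, p))
        beta[j] = s / r[j][j]
    return beta

# Toy data lying exactly on the line y = 1 + 2*x (first column is the intercept)
X = [[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]]
y = [1.0, 3.0, 5.0, 7.0]
beta = lstsq_qr(X, y)
print(beta)
```

For real work a Householder-based QR from a numerical library is preferable, since Gram-Schmidt can lose orthogonality on badly conditioned model matrices.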

A linear regression example

Plenty of documentation is available from the scala-glm github repo which I won’t repeat here. But to give a rough idea of how things work, I’ll run through an interactive session for the linear regression example.

First, download a dataset from the UCI ML Repository to disk for subsequent analysis (caching the file on disk is good practice, as it avoids unnecessary load on the UCI server, and allows running the code off-line):

import scalaglm._
import breeze.linalg._

val url = "http://archive.ics.uci.edu/ml/machine-learning-databases/00291/airfoil_self_noise.dat"
val fileName = "self-noise.csv"

// download the file to disk if it hasn't been already
val file = new java.io.File(fileName)
if (!file.exists) {
  val s = new java.io.PrintWriter(file)
  val data = scala.io.Source.fromURL(url).getLines
  data.foreach(l => s.write(l.trim.
    split('\t').filter(_ != "").
    mkString("", ",", "\n")))
  s.close
}

Once we have a CSV file on disk, we can load it up and look at it.

val mat = csvread(new java.io.File(fileName))
// mat: breeze.linalg.DenseMatrix[Double] =
// 800.0    0.0  0.3048  71.3  0.00266337  126.201
// 1000.0   0.0  0.3048  71.3  0.00266337  125.201
// 1250.0   0.0  0.3048  71.3  0.00266337  125.951
// ...
println("Dim: " + mat.rows + " " + mat.cols)
// Dim: 1503 6
val figp = Utils.pairs(mat, List("Freq", "Angle", "Chord", "Velo", "Thick", "Sound"))
// figp: breeze.plot.Figure = breeze.plot.Figure@37718125

We can then regress the response in the final column on the other variables.

val y = mat(::, 5) // response is the final column
// y: DenseVector[Double] = DenseVector(126.201, 125.201, ...
val X = mat(::, 0 to 4)
// X: breeze.linalg.DenseMatrix[Double] =
// 800.0    0.0  0.3048  71.3  0.00266337
// 1000.0   0.0  0.3048  71.3  0.00266337
// 1250.0   0.0  0.3048  71.3  0.00266337
// ...
val mod = Lm(y, X, List("Freq", "Angle", "Chord", "Velo", "Thick"))
// mod: scalaglm.Lm =
// Lm(DenseVector(126.201, 125.201, ...
mod.summary
// Estimate	 S.E.	 t-stat	p-value		Variable
// ---------------------------------------------------------
// 132.8338	 0.545	243.866	0.0000 *	(Intercept)
//  -0.0013	 0.000	-30.452	0.0000 *	Freq
//  -0.4219	 0.039	-10.847	0.0000 *	Angle
// -35.6880	 1.630	-21.889	0.0000 *	Chord
//   0.0999	 0.008	12.279	0.0000 *	Velo
// -147.3005	15.015	-9.810	0.0000 *	Thick
// Residual standard error:   4.8089 on 1497 degrees of freedom
// Multiple R-squared: 0.5157, Adjusted R-squared: 0.5141
// F-statistic: 318.8243 on 5 and 1497 DF, p-value: 0.00000
val fig = mod.plots
// fig: breeze.plot.Figure = breeze.plot.Figure@60d7ebb0

There is a .predict method for generating point predictions (and standard errors) given a new model matrix, and fitting GLMs is very similar – these things are covered in the quickstart guide for the library.
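
As a quick sanity check on the summary output printed by mod.summary above, the following Python sketch recomputes some of the reported diagnostics from the other printed numbers, using the standard textbook formulas. The function names are hypothetical, and because the inputs are the rounded values from the table, the results only agree with the printed ones approximately.

```python
def t_stat(estimate, se):
    """t-statistic for a coefficient: estimate divided by its standard error."""
    return estimate / se

def adjusted_r2(r2, n, k):
    """Adjusted R-squared for n observations and k predictors (excluding the intercept)."""
    return 1.0 - (1.0 - r2) * (n - 1) / (n - k - 1)

def f_stat(r2, n, k):
    """Overall F-statistic on k and n - k - 1 degrees of freedom."""
    return (r2 / k) / ((1.0 - r2) / (n - k - 1))

n, k = 1503, 5   # rows of the data matrix and number of covariates
r2 = 0.5157      # multiple R-squared, as printed (rounded)

adj = adjusted_r2(r2, n, k)
f = f_stat(r2, n, k)
t_intercept = t_stat(132.8338, 0.545)  # rounded estimate and S.E. for the intercept
print(n - k - 1, adj, f, t_intercept)
```

The residual degrees of freedom, n - k - 1 = 1497, match the summary exactly, and the other quantities match to within the rounding of the printed inputs.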

Summary

scala-glm is a small Scala library built on top of the Breeze numerical library which enables simple and convenient regression modelling in Scala. It is reasonably well documented and usable in its current form, but I intend to gradually add additional features according to demand as time permits.

Statistical computing with Scala free on-line course

I’ve recently delivered a three-day intensive short-course on Scala for statistical computing and data science. The course seemed to go well, and the experience has convinced me that Scala should be used a lot more by statisticians and data scientists for a range of problems in statistical computing. In particular, the simplicity of writing fast efficient parallel algorithms is reason alone to take a careful look at Scala. With a view to helping more statisticians get to grips with Scala, I’ve decided to freely release all of the essential materials associated with the course: the course notes (as PDF), code fragments, complete examples, end-of-chapter exercises, etc. Although I developed the materials with the training course in mind, the course notes are reasonably self-contained, making the course quite suitable for self-study. At some point I will probably flesh out the notes into a proper book, but that will probably take me a little while.

I’ve written a brief self-study guide to point people in the right direction. For people studying the material in their spare time, the course is probably best done over nine weeks (one chapter per week), and this will then cover material at a similar rate to a typical MOOC.

The nine chapters are:

1. Introduction
2. Scala and FP Basics
3. Collections
4. Scala Breeze
5. Monte Carlo
6. Statistical modelling
7. Tools
8. Apache Spark
9. Advanced topics

For anyone frustrated by the limitations of dynamic languages such as R, Python or Octave, this course should provide a good pathway to an altogether more sophisticated, modern programming paradigm.

Books on Scala for statistical computing and data science

Introduction

People regularly ask me about books and other resources for getting started with Scala for statistical computing and data science. This post will focus on books, but it’s worth briefly noting that there are a number of other resources available, on-line and otherwise, that are also worth considering. I particularly like the Coursera course Functional Programming Principles in Scala – I still think this is probably the best way to get started with Scala and functional programming for most people. In fact, there is an entire Functional Programming in Scala Specialization that is worth considering – I’ll probably discuss that more in another post. I’ve got a draft page of Scala links which has a bias towards scientific and statistical computing, and I’m currently putting together a short course in that area, which I’ll also discuss further in future posts. But this post will concentrate on books.

Reading list

Getting started with Scala

Before one can dive into statistical computing and data science using Scala, it’s a good idea to understand a bit about the language and about functional programming. There are by now many books on Scala, and I haven’t carefully reviewed all of them, but I’ve looked at enough to have an idea about good ways of getting started.

  • Programming in Scala: Third edition, Odersky et al, Artima.
    • This is the Scala book, often referred to on-line as PinS. It is a weighty tome, and works through the Scala language in detail, starting from the basics. Every serious Scala programmer should own this book. However, it isn’t the easiest introduction to the language.
  • Scala for the Impatient, Horstmann, Addison-Wesley.
    • As the name suggests, this is a much quicker and easier introduction to Scala than PinS, but assumes reasonable familiarity with programming in general, and sort-of assumes that the reader has a basic knowledge of Java and the JVM ecosystem. That said, it does not assume that the reader is a Java expert. My feeling is that for someone who has a reasonable programming background and a passing familiarity with Java, then this book is probably the best introduction to the language. Note that there is a second edition in the works.
  • Functional Programming in Scala, Chiusano and Bjarnason, Manning.
    • It is possible to write Scala code in the style of "Java-without-the-semi-colons", but really the whole point of Scala is to move beyond that kind of Object-Oriented programming style. How much you venture down the path towards pure Functional Programming is very much a matter of taste, but many of the best Scala programmers are pretty hard-core FP, and there’s probably a reason for that. But many people coming to Scala don’t have a strong FP background, and getting up to speed with strongly-typed FP isn’t easy for people who only know an imperative (Object-Oriented) style of programming. This is the book that will help you to make the jump to FP. Sometimes referred to online as FPiS, or more often even just as the red book, this is also a book that every serious Scala programmer should own (and read!). Note that it isn’t really a book about Scala – it is a book about strongly typed FP that just "happens" to use Scala for illustrating the ideas. Consequently, you will probably want to augment this book with a book that really is about Scala, such as one of the books above. Since this is the first book on the list published by Manning, I should also mention how much I like computing books from this publisher. They are typically well-produced, and their paper books (pBooks) come with complimentary access to well-produced DRM-free eBook versions, however you purchase them.
  • Functional and Reactive Domain Modeling, Ghosh, Manning.
    • This is another book that isn’t really about Scala, but about software engineering using a strongly typed FP language. But again, it uses Scala to illustrate the ideas, and is an excellent read. You can think of it as a more practical "hands-on" follow-up to the red book, which shows how the ideas from the red book translate into effective solutions to real-world problems.
  • Structure and Interpretation of Computer Programs, second edition, Abelson et al, MIT Press.
    • This is not a Scala book! This is the only book in this list which doesn’t use Scala at all. I’ve included it on the list because it is one of the best books on programming that I’ve read, and is the book that I wish someone had told me about 20 years ago! In fact the book uses Scheme (a Lisp derivative) as the language to illustrate the ideas. There are obviously important differences between Scala and Scheme – eg. Scala is strongly statically typed and compiled, whereas Scheme is dynamically typed and interpreted. However, there are also similarities – eg. both languages support and encourage a functional style of programming but are not pure FP languages. Referred to on-line as SICP this book is a classic. Note that there is no need to buy a paper copy if you like eBooks, since electronic versions are available free on-line.

Scala for statistical computing and data science

  • Scala for Data Science, Bugnion, Packt.
    • Not to be confused with the (terrible) book, Scala for machine learning by the same publisher. Scala for Data Science is my top recommendation for getting started with statistical computing and data science applications using Scala. I have reviewed this book in another post, so I won’t say more about it here (but I like it).
  • Scala Data Analysis Cookbook, Manivannan, Packt.
    • I’m not a huge fan of the cookbook format, but this book is really mis-named, as it isn’t really a cookbook and isn’t really about data analysis in Scala! It is really a book about Apache Spark, and proceeds fairly sequentially in the form of a tutorial introduction to Spark. Spark is an impressive piece of technology, and it is obviously one of the factors driving interest in Scala, but it’s important to understand that Spark isn’t Scala, and that many typical data science applications will be better tackled using Scala without Spark. I’ve not read this book cover-to-cover as it offers little over Scala for Data Science, but its coverage of Spark is a bit more up-to-date than the Spark books I mention below, so it could be of interest to those who are mainly interested in Scala for Spark.
  • Scala High Performance Programming, Theron and Diamant, Packt.
    • This is an interesting book, fundamentally about developing high performance streaming data processing algorithm pipelines in Scala. It makes no reference to Spark. The running application is an on-line financial trading system. It takes a deep dive into understanding performance in Scala and on the JVM, and looks at how to benchmark and profile performance, diagnose bottlenecks and optimise code. This is likely to be of more interest to those interested in developing efficient algorithms for scientific and statistical computing rather than applied data scientists, but it covers some interesting material not covered by any of the other books in this list.
  • Learning Spark, Karau et al, O’Reilly.
    • This book provides an introduction to Apache Spark, written by some of the people who developed it. Spark is a big data analytics framework built on top of Scala. It is arguably the best available framework for big data analytics on computing clusters in the cloud, and hence there is a lot of interest in it. The book is a perfectly good introduction to Spark, and shows most examples implemented using the Java and Python APIs in addition to the canonical Scala (Spark Shell) implementation. This is useful for people working with multiple languages, but can be mildly irritating to anyone who is only interested in Scala. However, the big problem with this (and every other) book on Spark is that Spark is evolving very quickly, and so by the time any book on Spark is written and published it is inevitably very out of date. It’s not clear that it is worth buying a book specifically about Spark at this stage, or whether it would be better to go for a book like Scala for Data Science, which has a couple of chapters of introduction to Spark, which can then provide a starting point for engaging with Spark’s on-line documentation (which is reasonably good).
  • Advanced Analytics with Spark, Ryza et al, O’Reilly.
    • This book has a bit of a "cookbook" feel to it, which some people like and some don’t. It’s really more like an "edited volume" with different chapters authored by different people. Unlike Learning Spark it focuses exclusively on the Scala API. The book basically covers the development of a bunch of different machine learning pipelines for a variety of applications. My main problem with this book is that it has aged particularly badly, as all of the pipelines are developed with raw RDDs, which isn’t how ML pipelines in Spark are constructed any more. So again, it’s difficult for me to recommend. The message here is that if you are thinking of buying a book about Spark, check very carefully when it was published and what version of Spark it covers and whether that is sufficiently recent to be of relevance to you.

Summary

There are lots of books to get started with Scala for statistical computing and data science applications. My "bare minimum" recommendation would be some generic Scala book (doesn’t really matter which one), the red book, and Scala for Data Science. After reading those, you will be very well placed to top-up your knowledge as required with on-line resources.

Scala for Data Science [book review]

This post will review the book: Scala for Data Science, Bugnion, Packt.

Disclaimer: This book review has not been solicited by the publisher (or anyone else) in any way. I purchased the review copy of this book myself. I have not received any benefit from the writing of this review.

Introduction

On this blog I previously reviewed the (terrible) book, Scala for machine learning by the same publisher. I was therefore rather wary of buying this book. But the topic coverage looked good, so I decided to buy it, and wasn’t disappointed. Scala for Data Science is my top recommendation for getting started with statistical computing and data science applications using Scala.

Overview

The book assumes a basic familiarity with programming in Scala, at around the level of someone who has completed the Functional Programming Principles in Scala Coursera course. That is, it (quite sensibly) doesn’t attempt to teach the reader how to program in Scala, but rather how to approach the development of data science applications using Scala. It introduces more advanced Scala idioms gradually (eg. typeclasses don’t appear until Chapter 5), so it is relatively approachable for those who aren’t yet Scala experts. The book does cover Apache Spark, but Spark isn’t introduced until Chapter 10, so it isn’t “just another Spark book”. Most of the book is about developing data science applications in Scala, completely independently of Spark. That said, it also provides one of the better introductions to Spark, so doubles up as a pretty good introductory Spark book, in addition to being a good introduction to the development of data science applications with Scala. It should probably be emphasised that the book is very much focused on data science, rather than statistical computing, but there is plenty of material of relevance to those who are more interested in statistical computing than applied data science.

Chapter by chapter

  1. Scala and Data Science – motivation for using Scala in preference to certain other languages I could mention…
  2. Manipulating data with Breeze – Breeze is the standard Scala library for scientific and statistical computing. It’s pretty good, but documentation is rather lacking. This Chapter provides a good tutorial introduction to Breeze, which should be enough to get people going sufficiently to be able to make some sense of the available on-line documentation.
  3. Plotting with breeze-viz – Breeze has some support for plotting and visualisation of data. It’s somewhat limited when compared to what is available in R, but is fine for interactive exploratory analysis. However, the available on-line documentation for breeze-viz is almost non-existent. This Chapter is the best introduction to breeze-viz that I have seen.
  4. Parallel collections and futures – the Scala standard library has built-in support for parallel and concurrent programming based on functional programming concepts such as parallel (monadic) collections and Futures. Again, this Chapter provides an excellent introduction to these powerful concepts, allowing the reader to start developing parallel algorithms for multi-core hardware with minimal fuss.
  5. Scala and SQL through JDBC – this Chapter looks at connecting to databases using standard JVM mechanisms such as JDBC. However, it gradually introduces more functional ways of interfacing with databases using typeclasses, motivating:
  6. Slick – a functional interface for SQL – an introduction to the Slick library for a more Scala-esque way of database interfacing.
  7. Web APIs – the practicalities of talking to web APIs. eg. authenticated HTTP requests and parsing of JSON responses.
  8. Scala and MongoDB – working with a NoSQL database from Scala
  9. Concurrency with Akka – Akka is the canonical implementation of the actor model in Scala, for building large concurrent applications. It is the foundation on which Spark is built.
  10. Distributed batch processing with Spark – a tutorial introduction to Apache Spark. Spark is a big data analytics framework built on top of Scala and Akka. It is arguably the best available framework for big data analytics on computing clusters in the cloud, and hence there is a lot of interest in it. Indeed, Spark is driving some of the interest in Scala.
  11. Spark SQL and DataFrames – interfacing with databases using Spark, and more importantly, an introduction to Spark’s DataFrame abstraction, which is now fundamental to developing machine learning pipelines in Spark.
  12. Distributed machine learning with MLLib – MLLib is the machine learning library for Spark. It is worth emphasising that unlike many early books on Spark, this chapter covers the newer DataFrame-based pipeline API, in addition to the original RDD-based API. Together, Chapters 10, 11 and 12 provide a pretty good tutorial introduction to Spark. After working through these, it should be easy to engage with the official on-line Spark documentation.
  13. Web APIs with Play – is concerned with developing a web API at the end of a data science pipeline.
  14. Visualisation with D3 and the Play framework – is concerned with integrating visualisation into a data science web application.

Summary

This book provides a good tutorial introduction to a large number of topics relevant to statisticians and data scientists interested in developing data science applications using Scala. After working through this book, readers should be well-placed to augment their knowledge with readily searchable on-line documentation.

In a follow-up post I will give a quick overview of some other books relevant to getting started with Scala for statistical computing and data science.