Bayesian inference for a logistic regression model (Part 2)

Part 2: The log posterior


This is the second part in a series of posts on MCMC-based Bayesian inference for a logistic regression model. If you are new to this series, please go back to Part 1.

In the previous post we looked at the basic modelling concepts, and how to fit the model using a variety of PPLs. In this post we will prepare for doing MCMC by considering the problem of computing the unnormalised log posterior for the model. We will then see how this posterior can be implemented in several different languages and libraries.


Basic structure

In Bayesian inference the posterior distribution is just the conditional distribution of the model parameters given the data, and therefore proportional to the joint distribution of the model and data. We often write this as

\displaystyle \pi(\theta|y) \propto \pi(\theta,y) = \pi(\theta)\pi(y|\theta).

Taking logs we have

\displaystyle \log \pi(\theta, y) = \log \pi(\theta) + \log \pi(y|\theta).

So (up to an additive constant) the log posterior is just the sum of the log prior and log likelihood. There are many good (numerical) reasons why we try to work exclusively with the log posterior and try to avoid ever evaluating the raw posterior density.
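
As a small illustration of the numerical point, here is a sketch in plain Python (the data values are made up, and none of this is code from the series). Multiplying a couple of thousand density values together underflows to zero in double precision, whereas the sum of their logs is perfectly well behaved.

```python
import math

def norm_pdf(x, sd=1.0):
    """Density of N(0, sd^2) at x."""
    return math.exp(-0.5 * (x / sd) ** 2) / (sd * math.sqrt(2.0 * math.pi))

# 2000 hypothetical observations, all equal to 1.0 for simplicity
xs = [1.0] * 2000

raw = 1.0
for x in xs:
    raw *= norm_pdf(x)  # product of raw densities

log_total = sum(math.log(norm_pdf(x)) for x in xs)  # sum of log densities

print(raw)        # 0.0 (the product has underflowed)
print(log_total)  # finite, roughly -2837.9
```

Once the raw density has underflowed to zero, ratios of densities (as needed by MCMC acceptance probabilities) are meaningless, whereas differences of log densities remain usable.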

For our example logistic regression model, the parameter vector \theta is just the vector of regression coefficients, \beta. We assumed independent mean zero normal priors for the components of this vector, so the log prior is just the sum of logs of normal densities. Many scientific libraries will have built-in functions for returning the log-pdf of standard distributions, but if an explicit form is required, the log of the density of a N(0,\sigma^2) at x is just

\displaystyle -\log(2\pi)/2 - \log|\sigma| - x^2/(2\sigma^2),

and the initial constant term normalising the density can often be dropped.
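
For concreteness, the explicit formula can be sketched in plain Python as follows. The function names are illustrative only, and the comparison against the standard library's NormalDist is just a sanity check.

```python
import math
from statistics import NormalDist

def log_norm_pdf(x, sigma):
    """Log density of N(0, sigma^2) at x, using the explicit formula."""
    return (-0.5 * math.log(2.0 * math.pi)
            - math.log(abs(sigma))
            - x * x / (2.0 * sigma * sigma))

def log_prior(beta, scales):
    """Sum of independent mean-zero normal log densities."""
    return sum(log_norm_pdf(b, s) for b, s in zip(beta, scales))

# sanity check against the standard library (NormalDist needs sigma > 0)
check = math.log(NormalDist(0.0, 10.0).pdf(3.0))
print(log_norm_pdf(3.0, 10.0), check)
```

Dropping the first term of log_norm_pdf changes the log prior only by a constant that does not depend on the coefficients, which is why it can be omitted when only the unnormalised log posterior is needed.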

Log-likelihood (first attempt)

Information from the data comes into the log posterior via the log-likelihood. The typical way to derive the likelihood for problems of this type is to assume the usual binary encoding of the data (success 1, failure 0). Then, for a Bernoulli observation with probability p_i,\ i=1,\ldots,n, the likelihood associated with observation y_i is

\displaystyle f(y_i|p_i) = \left[ \hphantom{1-}p_i \quad :\ y_i=1 \atop 1-p_i \quad :\ y_i=0 \right. \quad = \quad p_i^{y_i}(1-p_i)^{1-y_i}.

Taking logs and then switching to parameter \eta_i=\text{logit}(p_i) we have

\displaystyle \log f(y_i|\eta_i) = y_i\eta_i - \log(1+e^{\eta_i}),

and summing over n observations gives the log likelihood

\displaystyle \log\pi(y|\eta) \equiv \ell(\eta;y) = y\cdot \eta - \mathbf{1}\cdot\log(\mathbf{1}+\exp{\eta}).

In the context of logistic regression, \eta is the linear predictor, so \eta=X\beta, giving

\displaystyle \ell(\beta;y) = y^\textsf{T}X\beta - \mathbf{1}^\textsf{T}\log(\mathbf{1}+\exp[X\beta]).

This is a perfectly good way of expressing the log-likelihood, and we will come back to it later when we want the gradient of the log-likelihood, but it turns out that there is a similar-but-different way of deriving it that results in an expression that is equivalent but slightly cheaper to evaluate.

Log-likelihood (second attempt)

For our second attempt, we will assume that the data is coded in a different way. Instead of the usual binary encoding, we will assume that the observation \tilde y_i is 1 for success and -1 for failure. This isn’t really a problem, since the two encodings are related by \tilde y_i = 2y_i-1. This new encoding is convenient in the context of a logit parameterisation since then

\displaystyle f(y_i|\eta_i) = \left[ p_i \ :\ \tilde y_i=1\atop 1-p_i\ :\ \tilde y_i=-1 \right. \ = \ \left[ (1+e^{-\eta_i})^{-1} \ :\ \tilde y_i=1\atop (1+e^{\eta_i})^{-1} \ :\ \tilde y_i=-1 \right. \ = \ (1+e^{-\tilde y_i\eta_i})^{-1} ,

and hence

\displaystyle \log f(y_i|\eta_i) = -\log(1+e^{-\tilde y_i\eta_i}).

Summing over observations gives

\displaystyle \ell(\eta;\tilde y) = -\mathbf{1}\cdot \log(\mathbf{1}+\exp[-\tilde y\circ \eta]),

where \circ denotes the Hadamard product. Substituting \eta=X\beta gives the log-likelihood

\displaystyle \ell(\beta;\tilde y) = -\mathbf{1}^\textsf{T} \log(\mathbf{1}+\exp[-\tilde y\circ X\beta]).

This likelihood is a bit cheaper to evaluate than the one previously derived. If we prefer to write in terms of the original data encoding, we can obviously do so as

\displaystyle \ell(\beta; y) = -\mathbf{1}^\textsf{T} \log(\mathbf{1}+\exp[-(2y-\mathbf{1})\circ (X\beta)]),

and in practice, it is this version that is typically used. To be clear, as an algebraic function of \beta and y the two functions are different. But they coincide for binary vectors y which is all that matters.
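
The claim that the two expressions coincide on binary data but differ as algebraic functions is easy to check numerically. The following is a sketch in plain Python, working directly with the linear predictor \eta; the values of eta and y are arbitrary.

```python
import math

def ll_first(eta, y):
    """First form: y . eta - sum log(1 + exp(eta)), with y coded 0/1."""
    return sum(yi * ei - math.log1p(math.exp(ei)) for yi, ei in zip(y, eta))

def ll_second(eta, y):
    """Second form: -sum log(1 + exp(-(2y - 1) * eta)), with y coded 0/1."""
    return -sum(math.log1p(math.exp(-(2 * yi - 1) * ei)) for yi, ei in zip(y, eta))

eta = [-2.0, -0.5, 0.0, 1.5, 3.0]     # arbitrary linear predictor values
y_binary = [0, 1, 1, 0, 1]            # a valid binary data vector
y_other = [0.5, 0.5, 0.5, 0.5, 0.5]   # not binary

print(ll_first(eta, y_binary), ll_second(eta, y_binary))  # these agree
print(ll_first(eta, y_other), ll_second(eta, y_other))    # these differ
```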



In R we can create functions for evaluating the log-likelihood, log-prior and log-posterior as follows (assuming that X and y are in scope).

ll = function(beta)
    sum(-log(1 + exp(-(2*y - 1)*(X %*% beta))))

lprior = function(beta)
    dnorm(beta[1], 0, 10, log=TRUE) + sum(dnorm(beta[-1], 0, 1, log=TRUE))

lpost = function(beta) ll(beta) + lprior(beta)


In Python (with NumPy and SciPy) we can define equivalent functions with

def ll(beta):
    return np.sum(-np.log(1 + np.exp(-(2*y - 1)*(X.dot(beta)))))

def lprior(beta):
    return (sp.stats.norm.logpdf(beta[0], loc=0, scale=10) + 
            np.sum(sp.stats.norm.logpdf(beta[range(1,p)], loc=0, scale=1)))

def lpost(beta):
    return ll(beta) + lprior(beta)


Python, like R, is a dynamic language, and relatively slow for MCMC algorithms. JAX is a tensor computation framework for Python that embeds a pure functional differentiable array processing language inside Python. JAX can JIT-compile high-performance code for both CPU and GPU, and has good support for parallelism. It is rapidly becoming the preferred way to develop high-performance sampling algorithms within the Python ecosystem. We can encode our log-posterior in JAX as follows.

def ll(beta):
    return jnp.sum(-jnp.log(1 + jnp.exp(-(2*y - 1)*jnp.dot(X, beta))))

def lprior(beta):
    return (jsp.stats.norm.logpdf(beta[0], loc=0, scale=10) + 
            jnp.sum(jsp.stats.norm.logpdf(beta[jnp.array(range(1,p))], loc=0, scale=1)))

def lpost(beta):
    return ll(beta) + lprior(beta)


JAX is a pure functional programming language embedded in Python. Pure functional programming languages are intrinsically more scalable and compositional than imperative languages such as R and Python, and are much better suited to exploit concurrency and parallelism. I’ve given a bunch of talks about this recently, so if you are interested in this, perhaps start with the materials for my Laplace’s Demon talk. Scala and Haskell are arguably the current best popular general purpose functional programming languages, so it is possibly interesting to consider the use of these languages for the development of scalable statistical inference codes. Since both languages are statically typed compiled functional languages with powerful type systems, they can be highly performant. However, neither is optimised for numerical (tensor) computation, so you should not expect that they will have performance comparable with optimised tensor computation frameworks such as JAX. We can encode our log-posterior in Scala (with Breeze) as follows:

  def ll(beta: DVD): Double =
      sum(-log(ones + exp(-1.0*(2.0*y - ones)*:*(X * beta))))

  def lprior(beta: DVD): Double =
    Gaussian(0,10).logPdf(beta(0)) + 
      sum(beta(1 until p).map(Gaussian(0,1).logPdf(_)))

  def lpost(beta: DVD): Double = ll(beta) + lprior(beta)


Apache Spark is a Scala library for distributed "big data" processing on clusters of machines. Despite fundamental differences, there is a sense in which Spark for Scala is a bit analogous to JAX for Python: both Spark and JAX are concerned with scalability, but they are targeting rather different aspects of scalability: JAX is concerned with getting regular sized data processing algorithms to run very fast (on GPUs), whereas Spark is concerned with running huge data processing tasks quickly by distributing work over clusters of machines. Despite obvious differences, the fundamental pure functional computational model adopted by both systems is interestingly similar: both systems are based on lazy transformations of immutable data structures using pure functions. This is a fundamental pattern for scalable data processing transcending any particular language, library or framework. We can encode our log posterior in Spark as follows.

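    // df is assumed to be the Spark Dataset holding the data (name not given in the text),
    // with the response in column 0 and the feature vector in column 1
    def ll(beta: DVD): Double =
      df.map{row =>
        val y = row.getAs[Double](0)
        val x = BDV.vertcat(BDV(1.0),toBDV(row.getAs[DenseVector](1)))
        -math.log(1.0 + math.exp(-1.0*(2.0*y-1.0)*(x.dot(beta))))}.reduce(_+_)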
    def lprior(beta: DVD): Double =
      Gaussian(0,10).logPdf(beta(0)) +
        sum(beta(1 until p).map(Gaussian(0,1).logPdf(_)))
    def lpost(beta: DVD): Double =
      ll(beta) + lprior(beta)


Haskell is an old, lazy pure functional programming language with an advanced type system, and remains the preferred language for the majority of functional programming language researchers. Hmatrix is the standard high performance numerical linear algebra library for Haskell, so we can use it to encode our log-posterior as follows.

ll :: Matrix Double -> Vector Double -> Vector Double -> Double
ll x y b = (negate) (vsum (cmap log (
                              (scalar 1) + (cmap exp (cmap (negate) (
                                                         (((scalar 2) * y) - (scalar 1)) * (x #> b)
                                                         ))))))

pscale :: [Double] -- prior standard deviations
pscale = [10.0, 1, 1, 1, 1, 1, 1, 1]
lprior :: Vector Double -> Double
lprior b = sum $  (\x -> logDensity (normalDistr 0.0 (snd x)) (fst x)) <$> (zip (toList b) pscale)
lpost :: Matrix Double -> Vector Double -> Vector Double -> Double
lpost x y b = (ll x y b) + (lprior b)

Again, a reminder that, here and elsewhere, there are various optimisations that could be done that I’m not bothering with. This is all just proof-of-concept code.


JAX proves that a pure functional DSL for tensor computation can be extremely powerful and useful. But embedding such a language in a dynamic imperative language like Python has a number of drawbacks. Dex is an experimental statically typed stand-alone DSL for differentiable array and tensor programming that attempts to combine some of the correctness and composability benefits of powerful statically typed functional languages like Scala and Haskell with the performance benefits of tensor computation systems like JAX. It is currently rather early in its development, but seems very interesting, and is already quite usable. We can encode our log-posterior in Dex as follows.

def ll (b: (Fin 8)=>Float) : Float =
  neg $  sum (log (map (\ x. (exp x) + 1) ((map (\ yi. 1 - 2*yi) y)*(x **. b))))

pscale = [10.0, 1, 1, 1, 1, 1, 1, 1] -- prior SDs
prscale = map (\ x. 1.0/x) pscale

def lprior (b: (Fin 8)=>Float) : Float =
  bs = b*prscale
  neg $  sum ((log pscale) + (0.5 .* (bs*bs)))

def lpost (b: (Fin 8)=>Float) : Float =
  (ll b) + (lprior b)

Next steps

Now that we have a way of evaluating the log posterior, we can think about constructing Markov chains having the posterior as their equilibrium distribution. In the next post we will look at one of the simplest ways of doing this: the Metropolis algorithm.

Complete runnable scripts are available from this public github repo.


Comonads for scientific and statistical computing in Scala


In a previous post I’ve given a brief introduction to monads in Scala, aimed at people interested in scientific and statistical computing. Monads are a concept from category theory which turn out to be exceptionally useful for solving many problems in functional programming. But most categorical concepts have a dual, usually prefixed with “co”, so the dual of a monad is the comonad. Comonads turn out to be especially useful for formulating algorithms from scientific and statistical computing in an elegant way. In this post I’ll illustrate their use in signal processing, image processing, numerical integration of PDEs, and Gibbs sampling (of an Ising model). Comonads enable the extension of a local computation to a global computation, and this pattern crops up all over the place in statistical computing.

Monads and comonads

Simplifying massively, from the viewpoint of a Scala programmer, a monad is a mappable (functor) type class augmented with the methods pure and flatMap:

trait Monad[M[_]] extends Functor[M] {
  def pure[T](v: T): M[T]
  def flatMap[T,S](v: M[T])(f: T => M[S]): M[S]
}

In category theory, the dual of a concept is typically obtained by “reversing the arrows”. Here that means reversing the direction of the methods pure and flatMap to get extract and coflatMap, respectively.

trait Comonad[W[_]] extends Functor[W] {
  def extract[T](v: W[T]): T
  def coflatMap[T,S](v: W[T])(f: W[T] => S): W[S]
}

So, while pure allows you to wrap plain values in a monad, extract allows you to get a value out of a comonad. So you can always get a value out of a comonad (unlike a monad). Similarly, while flatMap allows you to transform a monad using a function returning a monad, coflatMap allows you to transform a comonad using a function which collapses a comonad to a single value. It is coflatMap (sometimes called extend) which can extend a local computation (producing a single value) to the entire comonad. We’ll look at how that works in the context of some familiar examples.

Applying a linear filter to a data stream

One of the simplest examples of a comonad is an infinite stream of data. I’ve discussed streams in a previous post. By focusing on infinite streams we know the stream will never be empty, so there will always be a value that we can extract. Which value does extract give? For a Stream encoded as some kind of lazy list, the only value we actually know is the value at the head of the stream, with subsequent values to be lazily computed as required. So the head of the list is the only reasonable value for extract to return.

Understanding coflatMap is a bit more tricky, but it is coflatMap that provides us with the power to apply a non-trivial statistical computation to the stream. The input is a function which transforms a stream into a value. In our example, that will be a function which computes a weighted average of the first few values and returns that weighted average as the result. But the return type of coflatMap must be a stream of such computations. Following the types, a few minutes’ thought reveals that the only reasonable thing to do is to return the stream formed by applying the weighted average function to all sub-streams, recursively. So, for a Stream s (of type Stream[T]) and an input function f: Stream[T] => S, we form a stream whose head is f(s) and whose tail is coflatMap(f) applied to s.tail. Again, since we are working with an infinite stream, we don’t have to worry about whether or not the tail is empty. This gives us our comonadic Stream, and it is exactly what we need for applying a linear filter to the data stream.
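
To make the "apply f to every sub-stream" idea concrete, here is a rough sketch in plain Python, using a finite non-empty list in place of an infinite stream. The function names and the data are illustrative only; this is not the Cats implementation.

```python
def extract(xs):
    """The focus of a non-empty list is its head."""
    return xs[0]

def coflat_map(f, xs):
    """Apply f to every suffix of xs: [f(xs), f(xs[1:]), f(xs[2:]), ...]."""
    return [f(xs[i:]) for i in range(len(xs))]

def weighted_average(xs, weights=(0.25, 0.5, 0.25)):
    """Local computation: weighted sum of the first few values of xs."""
    return sum(w * x for w, x in zip(weights, xs))

xs = [1.0, 2.0, 4.0, 8.0, 16.0]
smoothed = coflat_map(weighted_average, xs)
print(smoothed)  # [2.25, 4.5, 9.0, 10.0, 4.0]
```

Because the list is finite, the last two values are computed from truncated windows; with an infinite stream that issue does not arise. Note also that extract(coflat_map(f, xs)) is just f(xs), and coflat_map(extract, xs) gives back xs, which are two of the comonad laws.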

In Scala, Cats is a library providing type classes from category theory, and instances of those type classes for parametrised types in the standard library. In particular, it provides us with comonadic functionality for the standard Scala Stream. Let’s start by defining a stream corresponding to the logistic map.

import cats._
import cats.implicits._

val lam = 3.7
def s = Stream.iterate(0.5)(x => lam*x*(1-x))
s.take(10).toList
// res0: List[Double] = List(0.5, 0.925, 0.25668749999999985,
//  0.7059564011718747, 0.7680532550204203, 0.6591455741499428, ...

Let us now suppose that we want to apply a linear filter to this stream, in order to smooth the values. The idea behind using comonads is that you figure out how to generate one desired value, and let coflatMap take care of applying the same logic to the rest of the structure. So here, we need a function to generate the first filtered value (since extract is focused on the head of the stream). A simple first attempt at a function to do this might look like the following.

  def linearFilterS(weights: Stream[Double])(s: Stream[Double]): Double =
    (weights, s).parMapN(_*_).sum

This aligns each weight in parallel with a corresponding value from the stream, and combines them using multiplication. The resulting (hopefully finite length) stream is then summed (with addition). We can test this with

linearFilterS(Stream(0.25,0.5,0.25))(s)
// res1: Double = 0.651671875

and let coflatMap extend this computation to the rest of the stream with something like:

s.coflatMap(linearFilterS(Stream(0.25,0.5,0.25))).take(5).toList
// res2: List[Double] = List(0.651671875, 0.5360828502929686, ...

This is all completely fine, but our linearFilterS function is specific to the Stream comonad, despite the fact that all we’ve used about it in the function is that it is parallelly composable and foldable. We can make this much more generic as follows:

  def linearFilter[F[_]: Foldable, G[_]](
    weights: F[Double], s: F[Double]
  )(implicit ev: NonEmptyParallel[F, G]): Double =
    (weights, s).parMapN(_*_).fold

This uses some fairly advanced Scala concepts which I don’t want to get into right now (I should also acknowledge that I had trouble getting the syntax right for this, and got help from Fabio Labella (@SystemFw) on the Cats gitter channel). But this version is more generic, and can be used to linearly filter other data structures than Stream. We can use this for regular Streams as follows:

s.coflatMap(s => linearFilter(Stream(0.25,0.5,0.25),s))
// res3: scala.collection.immutable.Stream[Double] = Stream(0.651671875, ?)

But we can apply this new filter to other collections. This could be other, more sophisticated, streams such as provided by FS2, Monix or Akka streams. But it could also be a non-stream collection, such as List:

val sl = s.take(10).toList
sl.coflatMap(sl => linearFilter(List(0.25,0.5,0.25),sl))
// res4: List[Double] = List(0.651671875, 0.5360828502929686, ...

Assuming that we have the Breeze scientific library available, we can plot the raw and smoothed trajectories.

def myFilter(s: Stream[Double]): Double =
  linearFilter(Stream(0.25, 0.5, 0.25),s)
val n = 500
import breeze.plot._
import breeze.linalg._
val fig = Figure(s"The (smoothed) logistic map (lambda=$lam)")
val p0 = fig.subplot(3,1,0)
p0 += plot(linspace(1,n,n),s.take(n))
p0.ylim = (0.0,1.0)
p0.title = s"The logistic map (lambda=$lam)"
val p1 = fig.subplot(3,1,1)
p1 += plot(linspace(1,n,n),s.coflatMap(myFilter).take(n))
p1.ylim = (0.0,1.0)
p1.title = "Smoothed by a simple linear filter"
val p2 = fig.subplot(3,1,2)
p2 += plot(linspace(1,n,n),s.coflatMap(myFilter).coflatMap(myFilter).coflatMap(myFilter).coflatMap(myFilter).coflatMap(myFilter).take(n))
p2.ylim = (0.0,1.0)
p2.title = "Smoothed with 5 applications of the linear filter"

Image processing and the heat equation

Streaming data is in no way the only context in which a comonadic approach facilitates an elegant approach to scientific and statistical computing. Comonads crop up anywhere where we want to extend a computation that is local to a small part of a data structure to the full data structure. Another commonly cited area of application of comonadic approaches is image processing (I should acknowledge that this section of the post is very much influenced by a blog post on comonadic image processing in Haskell). However, the kinds of operations used in image processing are in many cases very similar to the operations used in finite difference approaches to numerical integration of partial differential equations (PDEs) such as the heat equation, so in this section I will blur (sic) the distinction between the two, and numerically integrate the 2D heat equation in order to Gaussian blur a noisy image.

First we need a simple image type which can have pixels of arbitrary type T (this is very important – all functors must be fully type polymorphic).

  import scala.collection.parallel.immutable.ParVector
  case class Image[T](w: Int, h: Int, data: ParVector[T]) {
    def apply(x: Int, y: Int): T = data(x*h+y)
    def map[S](f: T => S): Image[S] = Image(w, h, data map f)
    def updated(x: Int, y: Int, value: T): Image[T] =
      Image(w, h, data.updated(x*h+y, value))
  }

Here I’ve chosen to back the image with a parallel immutable vector. This wasn’t necessary, but since this type has a map operation which automatically parallelises over multiple cores, any map operations applied to the image will be automatically parallelised. This will ultimately lead to all of our statistical computations being automatically parallelised without us having to think about it.

As it stands, this image isn’t comonadic, since it doesn’t implement extract or coflatMap. Unlike the case of Stream, there isn’t really a uniquely privileged pixel, so it’s not clear what extract should return. For many data structures of this type, we make them comonadic by adding a “cursor” pointing to a “current” element of interest, and use this as the focus for computations applied with coflatMap. This is simplest to explain by example. We can define our “pointed” image type as follows:

  case class PImage[T](x: Int, y: Int, image: Image[T]) {
    def extract: T = image(x, y)
    def map[S](f: T => S): PImage[S] = PImage(x, y, image map f)
    def coflatMap[S](f: PImage[T] => S): PImage[S] = PImage(
      x, y, Image(image.w, image.h,
      (0 until (image.w * image.h)).toVector.par.map(i => {
        val xx = i / image.h
        val yy = i % image.h
        f(PImage(xx, yy, image))
      })))

The closing brace of the class is deliberately missing, as the definition is not quite finished. Here x and y represent the location of our cursor, so extract returns the value of the pixel indexed by our cursor. Similarly, coflatMap forms an image where the value of the image at each location is the result of applying the function f to the image which had the cursor set to that location. Clearly f should use the cursor in some way, otherwise the image will have the same value at every pixel location. Note that map and coflatMap operations will be automatically parallelised. The intuitive idea behind coflatMap is that it extends local computations. For the stream example, the local computation was a linear combination of nearby values. Similarly, in image analysis problems, we often want to apply a linear filter to nearby pixels. We can get at the pixel at the cursor location using extract, but we probably also want to be able to move the cursor around to nearby locations. We can do that by adding some appropriate methods to complete the class definition.

    def up: PImage[T] = {
      val py = y-1
      val ny = if (py >= 0) py else (py + image.h)
      PImage(x, ny, image)
    }
    def down: PImage[T] = {
      val py = y+1
      val ny = if (py < image.h) py else (py - image.h)
      PImage(x, ny, image)
    }
    def left: PImage[T] = {
      val px = x-1
      val nx = if (px >= 0) px else (px + image.w)
      PImage(nx, y, image)
    }
    def right: PImage[T] = {
      val px = x+1
      val nx = if (px < image.w) px else (px - image.w)
      PImage(nx, y, image)
    }
  }

Here each method returns a new pointed image with the cursor shifted by one pixel in the appropriate direction. Note that I’ve used periodic boundary conditions here, which often makes sense for numerical integration of PDEs, but makes less sense for real image analysis problems. Note that we have embedded all “indexing” issues inside the definition of our classes. Now that we have it, none of the statistical algorithms that we develop will involve any explicit indexing. This makes it much less likely to develop algorithms containing bugs corresponding to “off-by-one” or flipped axis errors.
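
The same pattern can be sketched compactly in plain Python, which may help to clarify how the cursor, the wrap-around movement and coflatMap fit together. This is only an illustration: the class name PGrid, the single move method and the equal-weight smoothing rule are all assumptions made for the example, and there is no parallelism here.

```python
from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class PGrid:
    """A w-by-h grid stored column by column, plus a cursor (x, y)."""
    x: int
    y: int
    w: int
    h: int
    data: Tuple[float, ...]

    def extract(self):
        # value at the cursor
        return self.data[self.x * self.h + self.y]

    def move(self, dx, dy):
        # shift the cursor, with periodic (wrap-around) boundary conditions
        return PGrid((self.x + dx) % self.w, (self.y + dy) % self.h,
                     self.w, self.h, self.data)

    def coflat_map(self, f):
        # value at each location = f applied to the grid focused on that location
        new = tuple(f(PGrid(i // self.h, i % self.h, self.w, self.h, self.data))
                    for i in range(self.w * self.h))
        return PGrid(self.x, self.y, self.w, self.h, new)

def smooth(g):
    """Local rule: plain average of a cell and its four neighbours."""
    return (g.extract() + g.move(0, -1).extract() + g.move(0, 1).extract()
            + g.move(-1, 0).extract() + g.move(1, 0).extract()) / 5.0

# a 4x4 grid that is zero everywhere except for a single cell at (1, 1)
g0 = PGrid(0, 0, 4, 4, tuple(1.0 if i == 5 else 0.0 for i in range(16)))
g1 = g0.coflat_map(smooth)
print(g1.move(1, 1).extract())  # 0.2
```

The local rule smooth never mentions an index; all of the indexing lives inside the class, which is exactly the point made above. With periodic boundaries and weights summing to one, the total of the grid values is preserved by each application.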

This class is now fine for our requirements. But if we wanted Cats to understand that this structure is really a comonad (perhaps because we wanted to use derived methods, such as coflatten), we would need to provide evidence for this. The details aren’t especially important for this post, but we can do it simply as follows:

  implicit val pimageComonad = new Comonad[PImage] {
    def extract[A](wa: PImage[A]) = wa.extract
    def coflatMap[A,B](wa: PImage[A])(f: PImage[A] => B): PImage[B] =
      wa.coflatMap(f)
    def map[A,B](wa: PImage[A])(f: A => B): PImage[B] = wa.map(f)
  }

It’s handy to have some functions for converting Breeze dense matrices back and forth with our image class.

  import breeze.linalg.{Vector => BVec, _}
  def BDM2I[T](m: DenseMatrix[T]): Image[T] =
    Image(m.cols, m.rows, m.data.toVector.par)
  def I2BDM(im: Image[Double]): DenseMatrix[Double] =
    new DenseMatrix(im.h,im.w,im.data.toArray)

Now we are ready to see how to use this in practice. Let’s start by defining a very simple linear filter.

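def fil(pi: PImage[Double]): Double = (2*pi.extract+
  pi.up.extract+pi.down.extract+pi.left.extract+pi.right.extract)/6.0 // neighbour terms assumed; weights sum to one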

This simple filter can be used to “smooth” or “blur” an image. However, from a more sophisticated viewpoint, exactly this type of filter can be used to represent one time step of a numerical method for time integration of the 2D heat equation. Now we can simulate a noisy image and apply our filter to it using coflatMap:

import breeze.stats.distributions.Gaussian
val bdm = DenseMatrix.tabulate(200,250){case (i,j) => math.cos(
  0.1*math.sqrt((i*i+j*j))) + Gaussian(0.0,2.0).draw}
val pim0 = PImage(0,0,BDM2I(bdm))
def pims = Stream.iterate(pim0)(_.coflatMap(fil))

Note that here, rather than just applying the filter once, I’ve generated an infinite stream of pointed images, each one representing an additional application of the linear filter. Thus the sequence represents the time solution of the heat equation with initial condition corresponding to our simulated noisy image.

We can render the first few frames to check that it seems to be working.

import breeze.plot._
val fig = Figure("Diffusing a noisy image")
pims.take(25).zipWithIndex.foreach{case (pim,i) => {
  val p = fig.subplot(5,5,i)
  p += image(I2BDM(pim.image))
}}

Note that the numerical integration is carried out in parallel on all available cores automatically. Other image filters can be applied, and other (parabolic) PDEs can be numerically integrated in an essentially similar way.

Gibbs sampling the Ising model

Another place where the concept of extending a local computation to a global computation crops up is in the context of Gibbs sampling a high-dimensional probability distribution by cycling through the sampling of each variable in turn from its full-conditional distribution. I’ll illustrate this here using the Ising model, so that I can reuse the pointed image class from above, but the principles apply to any Gibbs sampling problem. In particular, the Ising model that we consider has a conditional independence structure corresponding to a graph of a square lattice. As above, we will use the comonadic structure of the square lattice to construct a Gibbs sampler. However, we can construct a Gibbs sampler for arbitrary graphical models in an essentially identical way by using a graph comonad.

Let’s begin by simulating a random image containing +/-1s:

import breeze.stats.distributions.{Binomial,Bernoulli}
val beta = 0.4
val bdm = DenseMatrix.tabulate(500,600){
  case (i,j) => (new Binomial(1,0.2)).draw
}.map(_*2 - 1) // random matrix of +/-1s
val pim0 = PImage(0,0,BDM2I(bdm))

We can use this to initialise our Gibbs sampler. We now need a Gibbs kernel representing the update of each pixel.

def gibbsKernel(pi: PImage[Int]): Int = {
   val sum = pi.up.extract+pi.down.extract+pi.left.extract+pi.right.extract
   val p1 = math.exp(beta*sum)
   val p2 = math.exp(-beta*sum)
   val probplus = p1/(p1+p2)
   if (new Bernoulli(probplus).draw) 1 else -1
}

So far so good, but there are a couple of issues that we need to consider before we plough ahead and start coflatMapping. The first is that pure functional programmers will object to the fact that this function is not pure. It is a stochastic function which has the side-effect of mutating the random number state. I’m just going to duck that issue here, as I’ve previously discussed how to fix it using probability monads, and I don’t want it to distract us here.

However, there is a more fundamental problem here relating to parallel versus sequential application of Gibbs kernels. coflatMap is conceptually parallel (irrespective of how it is implemented) in that all computations used to build the new comonad are based solely on the information available in the starting comonad. On the other hand, detailed balance of the Markov chain will only be preserved if the kernels for each pixel are applied sequentially. So if we coflatMap this kernel over the image we will break detailed balance. I should emphasise that this has nothing to do with the fact that I’ve implemented the pointed image using a parallel vector. Exactly the same issue would arise if we switched to backing the image with a regular (sequential) immutable Vector.

The trick here is to recognise that if we coloured alternate pixels black and white using a chequerboard pattern, then all of the black pixels are conditionally independent given the white pixels and vice-versa. Conditionally independent pixels can be updated by parallel application of a Gibbs kernel. So we just need separate kernels for updating odd and even pixels.

def oddKernel(pi: PImage[Int]): Int =
  if ((pi.x+pi.y) % 2 != 0) pi.extract else gibbsKernel(pi)
def evenKernel(pi: PImage[Int]): Int =
  if ((pi.x+pi.y) % 2 == 0) pi.extract else gibbsKernel(pi)

Each of these kernels can be coflatMapped over the image preserving detailed balance of the chain. So we can now construct an infinite stream of MCMC iterations as follows.
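
The chequerboard scheme can also be sketched without any comonadic machinery. The following plain Python version (function names, grid size and seed are all illustrative assumptions) makes the key point explicit: each half-sweep reads only the old grid, and changes only the sites of one colour. Note that with wrap-around boundaries the two-colour scheme requires even grid dimensions, as is the case for the 500 by 600 image above.

```python
import math
import random

def gibbs_draw(grid, x, y, beta, rng):
    """Sample the spin at (x, y) from its full conditional given its four neighbours."""
    w, h = len(grid), len(grid[0])
    s = (grid[x][(y - 1) % h] + grid[x][(y + 1) % h]
         + grid[(x - 1) % w][y] + grid[(x + 1) % w][y])
    # equal to exp(beta*s) / (exp(beta*s) + exp(-beta*s))
    p_plus = 1.0 / (1.0 + math.exp(-2.0 * beta * s))
    return 1 if rng.random() < p_plus else -1

def half_sweep(grid, parity, beta, rng):
    """Update only sites with (x + y) % 2 == parity, reading only the old grid."""
    w, h = len(grid), len(grid[0])
    return [[gibbs_draw(grid, x, y, beta, rng) if (x + y) % 2 == parity
             else grid[x][y]
             for y in range(h)] for x in range(w)]

def sweep(grid, beta, rng):
    """One full iteration: update one colour, then the other."""
    return half_sweep(half_sweep(grid, 0, beta, rng), 1, beta, rng)

rng = random.Random(42)
grid0 = [[rng.choice([-1, 1]) for _ in range(6)] for _ in range(6)]
grid_half = half_sweep(grid0, 0, 0.4, rng)  # sites of the other colour are untouched
grid1 = sweep(grid0, 0.4, rng)
```

Since every neighbour of an updated site has the other colour, and is therefore left unchanged within a half-sweep, updating all sites of one colour "at once" gives the same distribution as updating them one at a time.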

def pims = Stream.iterate(pim0)(_.coflatMap(oddKernel).
  coflatMap(evenKernel))

We can animate the first few iterations with:

import breeze.plot._
val fig = Figure("Ising model Gibbs sampler")
fig.width = 1000
fig.height = 800
pims.take(50).zipWithIndex.foreach{case (pim,i) => {
  print(s"$i ")
  val p = fig.subplot(1,1,0)
  p.title = s"Ising model: frame $i"
  p += image(I2BDM(pim.image.map{_.toDouble}))
}}

Here I have a movie showing the first 1000 iterations. Note that YouTube seems to have over-compressed it, but you should get the basic idea.

Again, note that this MCMC sampler runs in parallel on all available cores, automatically. This issue of odd/even pixel updating emphasises another issue that crops up a lot in functional programming: very often, thinking about how to express an algorithm functionally leads to an algorithm which parallelises naturally. For general graphs, figuring out which groups of nodes can be updated in parallel is essentially the graph colouring problem. I’ve discussed this previously in relation to parallel MCMC in:

Wilkinson, D. J. (2005) Parallel Bayesian Computation, Chapter 16 in E. J. Kontoghiorghes (ed.) Handbook of Parallel Computing and Statistics, Marcel Dekker/CRC Press, 481-512.
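
As a toy illustration of the colouring idea, here is a sketch of a simple greedy colouring in plain Python. This is not taken from the chapter cited above, and a greedy pass is not guaranteed to use the minimum number of colours on a general graph. Nodes sharing a colour have no edges between them, so they can be updated in parallel.

```python
def greedy_colouring(neighbours):
    """Give each node the smallest colour not already used by a coloured neighbour."""
    colour = {}
    for node in neighbours:
        used = {colour[n] for n in neighbours[node] if n in colour}
        c = 0
        while c in used:
            c += 1
        colour[node] = c
    return colour

def colour_classes(colour):
    """Group nodes by colour; each group can be updated in parallel."""
    groups = {}
    for node, c in colour.items():
        groups.setdefault(c, []).append(node)
    return groups

# 4x4 square lattice with periodic boundaries, nodes labelled (x, y)
n = 4
lattice = {(x, y): [((x + 1) % n, y), ((x - 1) % n, y),
                    (x, (y + 1) % n), (x, (y - 1) % n)]
           for x in range(n) for y in range(n)}
colours = greedy_colouring(lattice)
groups = colour_classes(colours)
```

For this small lattice the greedy pass recovers the two-colour chequerboard pattern used above.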

Further reading

There are quite a few blog posts discussing comonads in the context of Haskell. In particular, the post on comonads for image analysis I mentioned previously, and this one on cellular automata. Bartosz’s post on comonads gives some connection back to the mathematical origins. Runar’s Scala comonad tutorial is the best source I know for comonads in Scala.

Full runnable code corresponding to this blog post is available from my blog repo.

scala-glm: Regression modelling in Scala


As discussed in the previous post, I’ve recently constructed and delivered a short course on statistical computing with Scala. Much of the course is concerned with writing statistical algorithms in Scala, typically making use of the scientific and numerical computing library, Breeze. Breeze has all of the essential tools necessary for building statistical algorithms, but doesn’t contain any higher level modelling functionality. As part of the course, I walked through how to build a small library for regression modelling on top of Breeze, including all of the usual regression diagnostics (such as standard errors, t-statistics, p-values, F-statistics, etc.). While preparing the course materials it occurred to me that it would be useful to package and document this code properly for general use. In advance of the course I packaged the code up into a bare-bones library, but since then I’ve fleshed it out, tidied it up and documented it properly, so it’s now ready for people to use.

The library covers PCA, linear regression modelling and simple one-parameter GLMs (including logistic and Poisson regression). The underlying algorithms are fairly efficient and numerically stable (e.g. linear regression uses the QR decomposition of the model matrix, and the GLM fitting uses QR within each IRLS step), though they are optimised more for clarity than speed. The library also includes a few utility functions and procedures, including a pairs plot (scatter-plot matrix).
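
To give a flavour of what least squares via the QR decomposition involves, here is a rough sketch in plain Python. It is not the scala-glm implementation: the factorisation here uses modified Gram-Schmidt purely because it is short to write down, and the function names and toy data are made up for the example. Writing X = QR, the least squares coefficients solve the triangular system R\beta = Q^\textsf{T}y.

```python
def qr_mgs(X):
    """Thin QR of a full-column-rank matrix (list of rows) by modified Gram-Schmidt."""
    n, p = len(X), len(X[0])
    Q = [[X[i][j] for i in range(n)] for j in range(p)]  # columns of X, overwritten by Q
    R = [[0.0] * p for _ in range(p)]
    for j in range(p):
        R[j][j] = sum(v * v for v in Q[j]) ** 0.5
        Q[j] = [v / R[j][j] for v in Q[j]]
        for k in range(j + 1, p):
            R[j][k] = sum(a * b for a, b in zip(Q[j], Q[k]))
            Q[k] = [b - R[j][k] * a for a, b in zip(Q[j], Q[k])]
    return Q, R  # Q is returned as a list of columns

def lstsq_qr(X, y):
    """Least squares coefficients: solve R beta = Q^T y by back-substitution."""
    Q, R = qr_mgs(X)
    p = len(R)
    qty = [sum(q * yi for q, yi in zip(Q[j], y)) for j in range(p)]
    beta = [0.0] * p
    for j in reversed(range(p)):
        beta[j] = (qty[j] - sum(R[j][k] * beta[k] for k in range(j + 1, p))) / R[j][j]
    return beta

# toy data lying exactly on y = 1 + 2x; the first column is the intercept
X = [[1.0, 0.0], [1.0, 1.0], [1.0, 2.0], [1.0, 3.0]]
y = [1.0, 3.0, 5.0, 7.0]
beta_hat = lstsq_qr(X, y)
print(beta_hat)  # approximately [1.0, 2.0]
```

Working with the triangular factor R avoids forming X^\textsf{T}X explicitly, which is the main reason QR-based fitting is preferred for numerical stability.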

A linear regression example

Plenty of documentation is available from the scala-glm github repo which I won’t repeat here. But to give a rough idea of how things work, I’ll run through an interactive session for the linear regression example.

First, download a dataset from the UCI ML Repository to disk for subsequent analysis (caching the file on disk is good practice, as it avoids unnecessary load on the UCI server, and allows running the code off-line):

import scalaglm._
import breeze.linalg._

val url = ""
val fileName = "self-noise.csv"

// download the file to disk if it hasn't been already
val file = new java.io.File(fileName)
if (!file.exists) {
  val s = new java.io.PrintWriter(file)
  val data = scala.io.Source.fromURL(url).getLines
  data.foreach(l => s.write(l.trim.
    split('\t').filter(_ != "").
    mkString("", ",", "\n")))
  s.close
}

Once we have a CSV file on disk, we can load it up and look at it.

val mat = csvread(new java.io.File(fileName))
// mat: breeze.linalg.DenseMatrix[Double] =
// 800.0    0.0  0.3048  71.3  0.00266337  126.201
// 1000.0   0.0  0.3048  71.3  0.00266337  125.201
// 1250.0   0.0  0.3048  71.3  0.00266337  125.951
// ...
println("Dim: " + mat.rows + " " + mat.cols)
// Dim: 1503 6
val figp = Utils.pairs(mat, List("Freq", "Angle", "Chord", "Velo", "Thick", "Sound"))
// figp: breeze.plot.Figure = breeze.plot.Figure@37718125

We can then regress the response in the final column on the other variables.

val y = mat(::, 5) // response is the final column
// y: DenseVector[Double] = DenseVector(126.201, 125.201, ...
val X = mat(::, 0 to 4)
// X: breeze.linalg.DenseMatrix[Double] =
// 800.0    0.0  0.3048  71.3  0.00266337
// 1000.0   0.0  0.3048  71.3  0.00266337
// 1250.0   0.0  0.3048  71.3  0.00266337
// ...
val mod = Lm(y, X, List("Freq", "Angle", "Chord", "Velo", "Thick"))
// mod: scalaglm.Lm =
// Lm(DenseVector(126.201, 125.201, ...
// Estimate	 S.E.	 t-stat	p-value		Variable
// ---------------------------------------------------------
// 132.8338	 0.545	243.866	0.0000 *	(Intercept)
//  -0.0013	 0.000	-30.452	0.0000 *	Freq
//  -0.4219	 0.039	-10.847	0.0000 *	Angle
// -35.6880	 1.630	-21.889	0.0000 *	Chord
//   0.0999	 0.008	12.279	0.0000 *	Velo
// -147.3005	15.015	-9.810	0.0000 *	Thick
// Residual standard error:   4.8089 on 1497 degrees of freedom
// Multiple R-squared: 0.5157, Adjusted R-squared: 0.5141
// F-statistic: 318.8243 on 5 and 1497 DF, p-value: 0.00000
val fig = mod.plots
// fig: breeze.plot.Figure = breeze.plot.Figure@60d7ebb0

There is a .predict method for generating point predictions (and standard errors) given a new model matrix, and fitting GLMs is very similar – these things are covered in the quickstart guide for the library.


scala-glm is a small Scala library built on top of the Breeze numerical library which enables simple and convenient regression modelling in Scala. It is reasonably well documented and usable in its current form, but I intend to gradually add additional features according to demand as time permits.