The smfsb R package

Introduction

In the previous post I gave a brief introduction to the third edition of my textbook, Stochastic Modelling for Systems Biology. The algorithms described in the book are illustrated by implementations in R, which are collected together in an R package on CRAN called smfsb. This post provides a brief introduction to the package and its capabilities.

Installation

The package is on CRAN – see the CRAN package page for details. So the simplest way to install it is to enter

install.packages("smfsb")

at the R command prompt. This will install the latest version that is on CRAN. Once installed, the package can be loaded with

library(smfsb)

The package is well-documented, so further information can be obtained with the usual R mechanisms, such as

vignette(package="smfsb")
vignette("smfsb")
help(package="smfsb")
?StepGillespie
example(StepCLE1D)

The version of the package on CRAN is almost certainly what you want. However, the package is developed on R-Forge – see the R-Forge project page for details. So the very latest version of the package can always be installed with

install.packages("smfsb", repos="http://R-Forge.R-project.org")

if you have a reason for wanting it.

A brief tutorial

The vignette gives a quick introduction to the library, which I don’t need to repeat verbatim here. If you are new to the package, I recommend working through that before continuing. Here I’ll concentrate on some of the new features associated with the third edition.

Simulating stochastic kinetic models

Much of the book is concerned with the simulation of stochastic kinetic models using exact and approximate algorithms. Although the primary focus of the text is the application to modelling of intra-cellular processes, the methods are also appropriate for population modelling of ecological and epidemic processes. For example, we can start by simulating a simple susceptible-infectious-recovered (SIR) disease epidemic model.

set.seed(2)
data(spnModels)

stepSIR = StepGillespie(SIR)
plot(simTs(SIR$M, 0, 8, 0.05, stepSIR),
  main="Exact simulation of the SIR model")

Exact simulation of the SIR epidemic model
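StepGillespie turns an SPN model into a function that advances the state exactly, one reaction event at a time. The following is a rough illustration of what such a step involves, not the package's code. It is a minimal Scala sketch of Gillespie's direct method for the SIR model, with infection S + I -> 2I at hazard c1·S·I and removal I -> R at hazard c2·I. The rate constants c1 = 0.1 and c2 = 0.5 and the initial state (100, 1, 0) are illustrative assumptions, not necessarily the values in the package's SIR model. GillespieSIR is a hypothetical name.

```scala
import scala.util.Random

object GillespieSIR {
  // Discrete state: numbers of susceptible, infectious and recovered individuals
  case class State(s: Int, i: Int, r: Int)

  // Simulate from time 0 to tMax with Gillespie's direct method and return the
  // state at time tMax. c1 and c2 are assumed mass-action rate constants.
  def simulate(x0: State, tMax: Double, c1: Double, c2: Double,
               rng: Random): State = {
    var x = x0
    var t = 0.0
    var done = false
    while (!done) {
      val h1 = c1 * x.s * x.i   // hazard of infection: S + I -> 2I
      val h2 = c2 * x.i         // hazard of removal:   I -> R
      val h0 = h1 + h2
      if (h0 <= 0.0) done = true   // no infectious individuals: nothing can happen
      else {
        // waiting time to the next event is Exponential(h0); 1 - u avoids log(0)
        t += -math.log(1.0 - rng.nextDouble()) / h0
        // stop if the next event falls beyond tMax; otherwise choose which
        // reaction fires, with probability proportional to its hazard
        if (t > tMax) done = true
        else if (rng.nextDouble() * h0 < h1) x = State(x.s - 1, x.i + 1, x.r)
        else x = State(x.s, x.i - 1, x.r + 1)
      }
    }
    x
  }
}
```

A call such as `GillespieSIR.simulate(GillespieSIR.State(100, 1, 0), 8.0, 0.1, 0.5, new Random(2))` returns one random end state. Recording the state at regular times as well would give a trajectory like the one plotted above.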
The focus of the text is stochastic simulation of discrete models, so that is the obvious place to start. But there is also support for continuous deterministic simulation.

plot(simTs(SIR$M, 0, 8, 0.05, StepEulerSPN(SIR)),
  main="Euler simulation of the SIR model")

Euler simulation of the SIR model
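StepEulerSPN integrates the deterministic rate equations instead. Under assumed mass-action SIR hazards (c1·S·I for infection and c2·I for removal), the rate equations are dS/dt = -c1 S I, dI/dt = c1 S I - c2 I and dR/dt = c2 I. A fixed-step explicit Euler scheme for them can be sketched in Scala as follows. This is not the package's implementation. EulerSIR and its methods are hypothetical names, and the rate constants are left as arguments.

```scala
object EulerSIR {
  // Continuous (deterministic) state
  case class State(s: Double, i: Double, r: Double)

  // One explicit Euler step of the mass-action rate equations
  def step(x: State, dt: Double, c1: Double, c2: Double): State = {
    val inf = c1 * x.s * x.i   // infection rate
    val rem = c2 * x.i         // removal rate
    State(x.s - inf * dt, x.i + (inf - rem) * dt, x.r + rem * dt)
  }

  // Integrate from time 0 to tMax using n equal steps
  def integrate(x0: State, tMax: Double, n: Int, c1: Double, c2: Double): State =
    (1 to n).foldLeft(x0)((x, _) => step(x, tMax / n, c1, c2))
}
```

The three rates of change sum to zero, so the Euler scheme preserves the total population up to rounding error, just as the exact stochastic simulation does.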
My favourite toy population dynamics model is the Lotka-Volterra (LV) model, so I tend to use this frequently as a running example throughout the book. We can simulate this (exactly) as follows.

stepLV = StepGillespie(LV)
plot(simTs(LV$M, 0, 30, 0.2, stepLV),
  main="Exact simulation of the LV model")

Exact simulation of the Lotka-Volterra model

Stochastic reaction-diffusion modelling

The first two editions of the book were almost exclusively concerned with well-mixed systems, where spatial effects are ignorable. One of the main new features of the third edition is a new chapter on spatially extended systems. The focus is on models related to the reaction-diffusion master equation (RDME) formulation, rather than on individual particle-based simulations. For these models, space is typically divided into a regular grid of voxels. Reactions take place as normal within each voxel, and additional reaction events are included, corresponding to the diffusion of particles to adjacent voxels. So to specify such a model, we just need an initial condition, a reaction model, and diffusion coefficients (one for each reacting species). We can then carry out exact simulation of an RDME model on a 1D spatial domain as follows.

N=20; T=30
x0=matrix(0, nrow=2, ncol=N)
rownames(x0) = c("x1", "x2")
x0[,round(N/2)] = LV$M
stepLV1D = StepGillespie1D(LV, c(0.6, 0.6))
xx = simTs1D(x0, 0, T, 0.2, stepLV1D, verb=TRUE)
image(xx[1,,], main="Prey", xlab="Space", ylab="Time")

Discrete 1D simulation of the LV model

image(xx[2,,], main="Predator", xlab="Space", ylab="Time")

Discrete 1D simulation of the LV model
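In the RDME formulation, diffusion is just another set of reaction events, in which a molecule in one voxel hops to a neighbouring voxel. The following is a minimal Scala sketch of exact simulation of the diffusion part alone, for a single species. It omits the within-voxel reactions that StepGillespie1D also handles. It assumes periodic boundaries and a hop rate d per molecule per direction. How the package's diffusion coefficients map onto hop rates is not something this sketch tries to reproduce. Diffusion1D is a hypothetical name.

```scala
import scala.util.Random

object Diffusion1D {
  // Exact simulation of pure diffusion of one species on a 1D grid of voxels.
  // Each molecule hops to each neighbouring voxel at rate d; the grid wraps
  // around at the ends (periodic boundaries, an assumption of this sketch).
  def simulate(x0: Vector[Int], d: Double, tMax: Double, rng: Random): Vector[Int] = {
    require(d > 0.0, "hop rate must be positive")
    val n = x0.length
    val x = x0.toArray
    val total = x.sum            // diffusion conserves the number of molecules
    var t = 0.0
    var done = total == 0
    while (!done) {
      val h0 = 2.0 * d * total   // every molecule can hop left or right
      t += -math.log(1.0 - rng.nextDouble()) / h0
      if (t > tMax) done = true
      else {
        // pick a molecule uniformly at random and find the voxel it is in,
        // which selects voxel i with probability proportional to 2 * d * x(i)
        var u = rng.nextInt(total)
        var i = 0
        while (u >= x(i)) { u -= x(i); i += 1 }
        // pick a direction with equal probability, wrapping at the ends
        val j = if (rng.nextBoolean()) (i + 1) % n else (i - 1 + n) % n
        x(i) -= 1
        x(j) += 1
      }
    }
    x.toVector
  }
}
```

The total hazard grows with the number of molecules and the number of voxels, so even this reaction-free version needs thousands of events for a modest grid. This is one reason exact spatial simulation is expensive.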
Exact simulation of discrete stochastic reaction diffusion systems is very expensive (and the reference implementation provided in the package is very inefficient), so we will often use diffusion approximations based on the CLE.

stepLV1DC = StepCLE1D(LV, c(0.6, 0.6))
xx = simTs1D(x0, 0, T, 0.2, stepLV1DC)
image(xx[1,,], main="Prey", xlab="Space", ylab="Time")

Spatial CLE simulation of the 1D LV model

image(xx[2,,], main="Predator", xlab="Space", ylab="Time")

Spatial CLE simulation of the 1D LV model
We can think of this algorithm as an explicit numerical integration of the obvious SPDE approximation to the exact model.
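For the diffusion part, the drift of this approximation in voxel i is d(x_{i-1} - 2x_i + x_{i+1}), which is d times the discrete Laplacian. This is where the SPDE interpretation comes from. The following is a minimal Scala sketch of one Euler-Maruyama step for pure diffusion of one species. Each left and right hop out of a voxel is treated as a reaction with hazard d·x_i, and receives the usual CLE increment of drift h·dt plus noise sqrt(h·dt)·Z. The periodic boundaries, the clamping of negative values inside the square root, and the name DiffusionCLE1D are assumptions of this sketch, not details taken from StepCLE1D.

```scala
import scala.util.Random

object DiffusionCLE1D {
  // One Euler-Maruyama step of the chemical Langevin equation for pure
  // diffusion on a periodic 1D grid. Each voxel has two hop "reactions"
  // (left and right), each with hazard d * x(i).
  def step(x: Vector[Double], d: Double, dt: Double, rng: Random): Vector[Double] = {
    val n = x.length
    // amount moved by one hop reaction: drift h*dt plus noise sqrt(h*dt)*Z
    def flux(xi: Double): Double = {
      val h = d * math.max(xi, 0.0)   // clamp to avoid sqrt of a negative value
      h * dt + math.sqrt(h * dt) * rng.nextGaussian()
    }
    val right = x.map(flux)   // amount moving from voxel i to i+1
    val left = x.map(flux)    // amount moving from voxel i to i-1
    // each voxel loses its own outgoing fluxes and gains those of its neighbours
    Vector.tabulate(n) { i =>
      x(i) - right(i) - left(i) + right((i - 1 + n) % n) + left((i + 1) % n)
    }
  }
}
```

Every hop removes an amount from one voxel and adds the same amount to a neighbour, so the total is conserved exactly, even though individual voxels can go slightly negative.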

The package also includes support for simulation of 2D systems. Again, we can use the spatial CLE to speed things up.

m=70; n=50; T=10
data(spnModels)
x0=array(0, c(2,m,n))
dimnames(x0)[[1]]=c("x1", "x2")
x0[,round(m/2),round(n/2)] = LV$M
stepLV2D = StepCLE2D(LV, c(0.6,0.6), dt=0.05)
xx = simTs2D(x0, 0, T, 0.5, stepLV2D)
N = dim(xx)[4]
image(xx[1,,,N],main="Prey",xlab="x",ylab="y")

Spatial CLE simulation of the 2D LV model

image(xx[2,,,N],main="Predator",xlab="x",ylab="y")

Spatial CLE simulation of the 2D LV model

Bayesian parameter inference

Although much of the book is concerned with the problem of forward simulation, the final chapters are concerned with the inverse problem of estimating model parameters, such as reaction rate constants, from data. A computational Bayesian approach is adopted, with the main emphasis being placed on “likelihood free” methods, which rely on forward simulation to avoid explicit computation of sample path likelihoods. The second edition included some rudimentary code for a likelihood free particle marginal Metropolis-Hastings (PMMH) particle Markov chain Monte Carlo (pMCMC) algorithm. The third edition includes a more complete and improved implementation, in addition to approximate inference algorithms based on approximate Bayesian computation (ABC).

The key function underpinning the PMMH approach is pfMLLik, which computes an estimate of marginal model log-likelihood using a (bootstrap) particle filter. There is a new implementation of this function with the third edition. There is also a generic implementation of the Metropolis-Hastings algorithm, metropolisHastings, which can be combined with pfMLLik to create a PMMH algorithm. PMMH algorithms are very slow, but a full demo of how to use these functions for parameter inference is included in the package and can be run with

demo(PMCMC)
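The idea behind pfMLLik is a bootstrap particle filter. At each time point it propagates the particles through the model, weights them by the observation density, adds the log of the average weight to the running total, and resamples. The following is a hedged illustration of that algorithm, not the package's code. It is a Scala sketch for a simple AR(1) latent state with Gaussian observation noise, using multinomial resampling at every step and an assumed N(mu, sig0^2) initial distribution. BootstrapPF and all parameter names are hypothetical.

```scala
import scala.util.Random

object BootstrapPF {
  // log of the N(y; m, s^2) density
  def logNormal(y: Double, m: Double, s: Double): Double =
    -0.5 * math.log(2 * math.Pi * s * s) - (y - m) * (y - m) / (2 * s * s)

  // Bootstrap particle filter estimate of the marginal log-likelihood for
  //   x_t = mu + a (x_{t-1} - mu) + sig * eps_t,   y_t = x_t + sigD * nu_t
  // with n particles and the initial state drawn from N(mu, sig0^2).
  def logLik(ys: Seq[Double], n: Int, mu: Double, a: Double, sig: Double,
             sigD: Double, sig0: Double, rng: Random): Double = {
    var xs = Array.fill(n)(mu + sig0 * rng.nextGaussian())
    var ll = 0.0
    for (y <- ys) {
      // propagate each particle through the state transition
      xs = xs.map(x => mu + a * (x - mu) + sig * rng.nextGaussian())
      // weight by the observation density, on the log scale for stability
      val lw = xs.map(x => logNormal(y, x, sigD))
      val mx = lw.max
      val w = lw.map(l => math.exp(l - mx))
      val sw = w.sum
      // the log of the average weight is this time point's contribution
      ll += mx + math.log(sw / n)
      // multinomial resampling using the cumulative weights
      val cum = w.scanLeft(0.0)(_ + _).tail
      xs = Array.fill(n) {
        val u = rng.nextDouble() * sw
        val k = cum.indexWhere(_ > u)
        xs(if (k < 0) n - 1 else k)
      }
    }
    ll
  }
}
```

The result is a noisy but unbiased estimate on the likelihood scale. Using it in place of the exact likelihood in a Metropolis-Hastings acceptance ratio is what turns metropolisHastings into a PMMH algorithm.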

Simple rejection-based ABC methods are facilitated by the (very simple) function abcRun, which just samples from a prior and then carries out independent simulations in parallel before computing summary statistics. A simple illustration of the use of the function is given below.

data(LVdata)
rprior <- function() { exp(c(runif(1, -3, 3),runif(1,-8,-2),runif(1,-4,2))) }
rmodel <- function(th) { simTs(c(50,100), 0, 30, 2, stepLVc, th) }
sumStats <- identity
ssd = sumStats(LVperfect)
distance <- function(s) {
    diff = s - ssd
    sqrt(sum(diff*diff))
}
rdist <- function(th) { distance(sumStats(rmodel(th))) }
out = abcRun(10000, rprior, rdist)
q=quantile(out$dist, c(0.01, 0.05, 0.1))
print(q)
##       1%       5%      10% 
## 772.5546 845.8879 881.0573
accepted = out$param[out$dist < q[1],]
print(summary(accepted))
##        V1                V2                  V3         
##  Min.   :0.06498   Min.   :0.0004467   Min.   :0.01887  
##  1st Qu.:0.16159   1st Qu.:0.0012598   1st Qu.:0.04122  
##  Median :0.35750   Median :0.0023488   Median :0.14664  
##  Mean   :0.68565   Mean   :0.0046887   Mean   :0.36726  
##  3rd Qu.:0.86708   3rd Qu.:0.0057264   3rd Qu.:0.36870  
##  Max.   :4.76773   Max.   :0.0309364   Max.   :3.79220
print(summary(log(accepted)))
##        V1                V2               V3         
##  Min.   :-2.7337   Min.   :-7.714   Min.   :-3.9702  
##  1st Qu.:-1.8228   1st Qu.:-6.677   1st Qu.:-3.1888  
##  Median :-1.0286   Median :-6.054   Median :-1.9198  
##  Mean   :-0.8906   Mean   :-5.877   Mean   :-1.9649  
##  3rd Qu.:-0.1430   3rd Qu.:-5.163   3rd Qu.:-0.9978  
##  Max.   : 1.5619   Max.   :-3.476   Max.   : 1.3329
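The logic of abcRun followed by the quantile-based selection above is easy to express generically. The following is a minimal Scala sketch of rejection ABC that keeps the closest fraction of the prior draws, mirroring the `out$dist < q[1]` step (1% here). It uses a toy problem rather than the LV model: inferring the mean of a unit-variance Gaussian from the sample mean of 20 observations, with a hypothetical observed summary of 2.0 and a U(-10, 10) prior. The names are hypothetical. Unlike abcRun, this sketch runs the simulations sequentially.

```scala
import scala.util.Random

object RejectionABC {
  // Draw nSims parameters from the prior, compute the distance between simulated
  // and observed summaries for each, and keep the closest keepFrac of the draws.
  def run(nSims: Int, keepFrac: Double, rprior: () => Double,
          rdist: Double => Double): Vector[Double] = {
    val sims = Vector.fill(nSims) {
      val th = rprior()
      (th, rdist(th))
    }
    val nKeep = math.max(1, (nSims * keepFrac).toInt)
    sims.sortBy(_._2).take(nKeep).map(_._1)
  }
}

object RejectionABCExample {
  val rng = new Random(42)
  val observedMean = 2.0   // hypothetical observed summary statistic
  // uniform prior on (-10, 10)
  def rprior(): Double = -10.0 + 20.0 * rng.nextDouble()
  // simulate 20 N(th, 1) observations and compare their mean with the data
  def rdist(th: Double): Double =
    math.abs(Vector.fill(20)(th + rng.nextGaussian()).sum / 20 - observedMean)
  val accepted: Vector[Double] = RejectionABC.run(10000, 0.01, () => rprior(), rdist)
}
```

Because the run function only sees `rprior` and `rdist`, the same code works for any model whose simulations can be reduced to a distance. This matches the way the R version is driven by `rprior` and `rdist` closures.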

Naive rejection-based ABC algorithms are notoriously inefficient, so the library also includes an implementation of a more efficient, sequential version of ABC, often known as ABC-SMC, in the function abcSmc. This function requires specification of a perturbation kernel to “noise up” the particles at each algorithm sweep. Again, the implementation is parallel, using the parallel package to run the required simulations in parallel on multiple cores. A simple illustration of use is given below.

rprior <- function() { c(runif(1, -3, 3), runif(1, -8, -2), runif(1, -4, 2)) }
dprior <- function(x, ...) { dunif(x[1], -3, 3, ...) + 
    dunif(x[2], -8, -2, ...) + dunif(x[3], -4, 2, ...) }
rmodel <- function(th) { simTs(c(50,100), 0, 30, 2, stepLVc, exp(th)) }
rperturb <- function(th){th + rnorm(3, 0, 0.5)}
dperturb <- function(thNew, thOld, ...){sum(dnorm(thNew, thOld, 0.5, ...))}
sumStats <- identity
ssd = sumStats(LVperfect)
distance <- function(s) {
    diff = s - ssd
    sqrt(sum(diff*diff))
}
rdist <- function(th) { distance(sumStats(rmodel(th))) }
out = abcSmc(5000, rprior, dprior, rdist, rperturb,
    dperturb, verb=TRUE, steps=6, factor=5)
## 6 5 4 3 2 1
print(summary(out))
##        V1                V2               V3        
##  Min.   :-2.9961   Min.   :-7.988   Min.   :-3.999  
##  1st Qu.:-1.9001   1st Qu.:-6.786   1st Qu.:-3.428  
##  Median :-1.2571   Median :-6.167   Median :-2.433  
##  Mean   :-1.0789   Mean   :-6.014   Mean   :-2.196  
##  3rd Qu.:-0.2682   3rd Qu.:-5.261   3rd Qu.:-1.161  
##  Max.   : 2.1128   Max.   :-2.925   Max.   : 1.706

We can then plot some results with

hist(out[,1],30,main="log(c1)")

ABC-SMC posterior for the LV model

hist(out[,2],30,main="log(c2)")

ABC-SMC posterior for the LV model

hist(out[,3],30,main="log(c3)")

ABC-SMC posterior for the LV model
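The core of ABC-SMC is the importance weight that corrects for proposing from a perturbed population rather than from the prior. The following is a hedged Scala sketch for a scalar parameter. It uses a toy problem, inferring the mean of a unit-variance Gaussian from the mean of 20 observations with a hypothetical observed value of 2.0 and a U(-10, 10) prior. At each step it proposes factor × n candidates and keeps the n closest, so the tolerance shrinks adaptively. That schedule is an assumption made for simplicity, not a description of how abcSmc chooses its tolerances. All names are hypothetical.

```scala
import scala.util.Random

object AbcSmc {
  // A minimal ABC-SMC sketch for a scalar parameter. Each step resamples the
  // current weighted particles, perturbs them with a Gaussian kernel, keeps the
  // n closest of factor * n proposals, and reweights by
  //   w(th) proportional to prior(th) / sum_j w_j K(th | th_j).
  def run(n: Int, steps: Int, factor: Int, rprior: () => Double,
          dprior: Double => Double, rdist: Double => Double,
          kernelSd: Double, rng: Random): Vector[(Double, Double)] = {
    // Gaussian kernel density up to a constant, which cancels on normalisation
    def kern(x: Double, m: Double): Double =
      math.exp(-0.5 * math.pow((x - m) / kernelSd, 2))
    // keep the n candidates whose simulated data are closest to the observations
    def closest(cands: Vector[Double]): Vector[Double] =
      cands.map(th => (th, rdist(th))).sortBy(_._2).take(n).map(_._1)
    // step 0: rejection ABC from the prior, equal weights
    var ths = closest(Vector.fill(factor * n)(rprior()))
    var ws = Vector.fill(ths.length)(1.0 / ths.length)
    for (_ <- 1 to steps) {
      val cum = ws.scanLeft(0.0)(_ + _).tail
      // resample an existing particle by weight and perturb it, retrying
      // until the proposal has positive prior density
      def propose(): Double = {
        val u = rng.nextDouble() * cum.last
        val k = cum.indexWhere(_ > u)
        val th = ths(if (k < 0) ths.length - 1 else k) + kernelSd * rng.nextGaussian()
        if (dprior(th) > 0.0) th else propose()
      }
      val newThs = closest(Vector.fill(factor * n)(propose()))
      // importance weights correct for sampling from the perturbed mixture
      val raw = newThs.map { th =>
        dprior(th) / ths.indices.map(j => ws(j) * kern(th, ths(j))).sum
      }
      val total = raw.sum
      ths = newThs
      ws = raw.map(_ / total)
    }
    ths.zip(ws)   // weighted particles approximating the ABC posterior
  }
}

object AbcSmcExample {
  val rng = new Random(11)
  // U(-10, 10) prior density
  def dprior(th: Double): Double = if (th > -10.0 && th < 10.0) 0.05 else 0.0
  // distance between the simulated and (hypothetical) observed sample means
  def rdist(th: Double): Double =
    math.abs(Vector.fill(20)(th + rng.nextGaussian()).sum / 20 - 2.0)
  val out: Vector[(Double, Double)] =
    AbcSmc.run(200, 3, 5, () => -10.0 + 20.0 * rng.nextDouble(), dprior, rdist, 0.5, rng)
}
```

As with the R version, the perturbation kernel appears twice. It is used once to propose (`rperturb`) and once to evaluate the mixture density in the weights (`dperturb`), so the two must describe the same kernel.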

Although the inference methods are illustrated in the book in the context of parameter inference for stochastic kinetic models, their implementation is generic, and can be used with any appropriate parameter inference problem.

The smfsbSBML package

smfsbSBML is another R package associated with the third edition of the book. This package is not on CRAN due to its dependency on a package not on CRAN, and hence is slightly less straightforward to install. Follow the available installation instructions to install the package. Once installed, you should be able to load the package with

library(smfsbSBML)

This package provides a function for reading in SBML files and parsing them into the simulatable stochastic Petri net (SPN) objects used by the main smfsb R package. Examples of suitable SBML models are included in the main smfsb GitHub repo. An appropriate SBML model can be read and parsed with a command like:

model = sbml2spn("mySbmlModel.xml")

The resulting value, model, is an SPN object which can be passed to simulation functions such as StepGillespie to construct stochastic simulation algorithms.
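Conceptually, an SPN bundles together the reaction structure (Pre and Post matrices), the rate constants or hazards, and an initial marking. As far as I know, these appear in the package as list components such as Pre, Post, h and M, and M is used above as LV$M. The following plain Scala sketch shows the same structure for the Lotka-Volterra model, with mass-action hazards c_j times the product over species of C(x_i, Pre_ji). The rate constants (1, 0.005, 0.6) are assumed illustrative values. The initial marking (50, 100) matches the one used in the ABC examples above. Spn is a hypothetical type, not part of smfsb or scala-smfsb.

```scala
object LvSpn {
  // A stochastic Petri net: Pre and Post matrices (rows = reactions,
  // columns = species), mass-action rate constants and an initial marking.
  case class Spn(species: Vector[String], pre: Vector[Vector[Int]],
                 post: Vector[Vector[Int]], c: Vector[Double], m: Vector[Int]) {
    // stoichiometry: net change in each species when each reaction fires
    def stoich: Vector[Vector[Int]] =
      pre.zip(post).map { case (a, b) => b.zip(a).map { case (bi, ai) => bi - ai } }
    // mass-action hazards: c_j times the product of C(x_i, Pre_ji)
    def hazards(x: Vector[Int]): Vector[Double] =
      pre.zip(c).map { case (row, cj) =>
        row.zip(x).foldLeft(cj) { case (h, (k, xi)) => h * choose(xi, k) }
      }
  }

  // binomial coefficient C(x, k) as a Double (zero when x < k)
  def choose(x: Int, k: Int): Double =
    (0 until k).foldLeft(1.0)((acc, r) => acc * (x - r) / (r + 1))

  // Lotka-Volterra: prey reproduction, predation, predator death
  val lv = Spn(
    species = Vector("x1", "x2"),
    pre  = Vector(Vector(1, 0), Vector(1, 1), Vector(0, 1)),
    post = Vector(Vector(2, 0), Vector(0, 2), Vector(0, 0)),
    c    = Vector(1.0, 0.005, 0.6),   // assumed rate constants
    m    = Vector(50, 100))           // initial marking
}
```

Pre and Post describe the structure independently of the kinetics. A generic simulator therefore needs only the stoichiometry and a hazard function, which is what makes parsing arbitrary SBML models into a common SPN form worthwhile.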

Other software

In addition to the above R packages, I also have some Python scripts for converting between SBML and the SBML-shorthand notation I use in the book. See the SBML-shorthand page for further details.

Although R is a convenient language for teaching and learning about stochastic simulation, it isn’t ideal for serious research-level scientific computing or computational statistics. So for the third edition of the book I have also developed scala-smfsb, a library which re-implements all of the models and algorithms from the third edition in Scala, a fast, efficient, strongly-typed, compiled, functional programming language. I’ll give an introduction to this library in a subsequent post, but in the meantime it is already well documented, so see the scala-smfsb repo for further details, including information on installation, getting started, a tutorial, examples, API docs, etc.

Source

This blog post started out as an RMarkdown document, the source of which can be found here.


Stochastic Modelling for Systems Biology, third edition

The third edition of my textbook, Stochastic Modelling for Systems Biology has recently been published by Chapman & Hall/CRC Press. The book has ISBN-10 113854928-2 and ISBN-13 978-113854928-9. It can be ordered from CRC Press, Amazon.com, Amazon.co.uk and similar book sellers.

I was fairly happy with the way that the second edition, published in 2011, turned out, and so I haven’t substantially re-written any of the text for the third edition. Instead, I’ve concentrated on adding in new material and improving the associated on-line resources. Those on-line resources are all free and open source, and hence available to everyone, irrespective of whether you have a copy of the new edition. I’ll give an introduction to those resources below (and in subsequent posts). The new material can be briefly summarised as follows:

  • New chapter on spatially extended systems, covering the spatial Gillespie algorithm for reaction diffusion master equation (RDME) models in 1- and 2-d, the next subvolume method, spatial CLE, scaling issues, etc.
  • Significantly expanded chapter on inference for stochastic kinetic models from data, covering approximate methods of inference (ABC), including ABC-SMC. The material relating to particle MCMC has also been improved and extended.
  • Updated R package, including code relating to all of the new material
  • New R package for parsing SBML models into simulatable stochastic Petri net models
  • New software library, written in Scala, replicating most of the functionality of the R packages in a fast, compiled, strongly typed, functional language

New content

Although some minor edits and improvements have been made throughout the text, there are two substantial new additions to the text in this new edition. The first is an entirely new chapter on spatially extended systems. The first two editions of the text focused on the implications of discreteness and stochasticity in chemical reaction systems, but maintained the well-mixed assumption throughout. This is a reasonable first approach, since discreteness and stochasticity are most pronounced in very small volumes where diffusion should be rapid. In any case, even these non-spatial models have very interesting behaviour, and become computationally challenging very quickly for non-trivial reaction networks. However, we know that, in fact, the cell is a very crowded environment, and so even at small spatial scales, many interesting processes are diffusion limited. It therefore seems appropriate to dedicate one chapter (the new Chapter 9) to studying some of the implications of relaxing the well-mixed assumption. Entire books can be written on stochastic reaction-diffusion systems, so here only a brief introduction is provided, based mainly around models in the reaction-diffusion master equation (RDME) style. Exact stochastic simulation algorithms are discussed, with implementations provided for the 1- and 2-d cases, and an appropriate Langevin approximation, the spatial CLE, is examined.

The second major addition is to the chapter on inference for stochastic kinetic models from data (now Chapter 11). The second edition of the book included a discussion of “likelihood free” Bayesian MCMC methods for inference, and provided a working implementation of likelihood free particle marginal Metropolis-Hastings (PMMH) for stochastic kinetic models. The third edition improves on that implementation, and discusses approximate Bayesian computation (ABC) as an alternative to MCMC for likelihood free inference. Implementation issues are discussed, and sequential ABC approaches are examined, concentrating in particular on the method known as ABC-SMC.

New software and on-line resources

Accompanying the text are new and improved on-line resources, all well-documented, free, and open source.

New website/GitHub repo

Information and materials relating to the previous editions were kept on my University website. All materials relating to this new edition are kept in a public GitHub repo: darrenjw/smfsb. This will be simpler to maintain, and will make it much easier for people to make copies of the material for use and studying off-line.

Updated R package(s)

Along with the second edition of the book I released an accompanying R package, “smfsb”, published on CRAN. This was a very popular feature, allowing anyone with R to trivially experiment with all of the models and algorithms discussed in the text. This R package has been updated, and a new version has been published to CRAN. The updates are all backwards-compatible with the version associated with the second edition of the text, so owners of that edition can still upgrade safely. I’ll give a proper introduction to the package, including the new features, in a subsequent post, but in the meantime, you can install/upgrade the package from a running R session with

install.packages("smfsb")

and then pop up a tutorial vignette with:

vignette("smfsb")

This should be enough to get you started.

In addition to the main R package, there is an additional R package for parsing SBML models into models that can be simulated within R. This package is not on CRAN, due to its dependency on a non-CRAN package. See the repo for further details.

There are also Python scripts available for converting SBML models to and from the shorthand SBML notation used in the text.

New Scala library

Another major new resource associated with the third edition of the text is a software library written in the Scala programming language. This library provides Scala implementations of all of the algorithms discussed in the book and implemented in the associated R packages. This then provides example implementations in a fast, efficient, compiled language, and is likely to be most useful for people wanting to use the methods in the book for research. Again, I’ll provide a tutorial introduction to this library in a subsequent post, but it is well-documented, with all necessary information needed to get started available at the scala-smfsb repo/website, including a step-by-step tutorial and some additional examples.

Bayesian hierarchical modelling with Rainier

Introduction

In the previous post I gave a brief introduction to Rainier, a new HMC-based probabilistic programming library/DSL for Scala. In that post I assumed that people were using the latest source version of the library. Since then, version 0.1.1 of the library has been released, so in this post I will demonstrate use of the released version of the software (using the binaries published to Sonatype), and will walk through a slightly more interesting example – a dynamic linear state space model with unknown static parameters. This is similar to, but slightly different from, the DLM example in the Rainier library. So to follow along with this post, all that is required is SBT.

An interactive session

First run SBT from an empty directory, and paste the following at the SBT prompt:

set libraryDependencies  += "com.stripe" %% "rainier-plot" % "0.1.1"
set scalaVersion := "2.12.4"
console

This should give a Scala REPL with appropriate dependencies (rainier-plot has all of the relevant transitive dependencies). We’ll begin with some imports, and then simulating some synthetic data from a dynamic linear state space model with an AR(1) latent state and Gaussian noise on the observations.

import com.stripe.rainier.compute._
import com.stripe.rainier.core._
import com.stripe.rainier.sampler._

implicit val rng = ScalaRNG(1)
val n = 60 // number of observations/time points
val mu = 3.0 // AR(1) mean
val a = 0.95 // auto-regressive parameter
val sig = 0.2 // AR(1) SD
val sigD = 3.0 // observational SD
val state = Stream.
  iterate(0.0)(x => mu + (x - mu) * a + sig * rng.standardNormal).
  take(n).toVector
val obs = state.map(_ + sigD * rng.standardNormal)

Now we have some synthetic data, let’s think about building a probabilistic program for this model. Start with a prior.

case class Static(mu: Real, a: Real, sig: Real, sigD: Real)
val prior = for {
  mu <- Normal(0, 10).param
  a <- Normal(1, 0.1).param
  sig <- Gamma(2,1).param
  sigD <- Gamma(2,2).param
  sp <- Normal(0, 50).param
} yield (Static(mu, a, sig, sigD), List(sp))

Note the use of a case class for wrapping the static parameters. Next, let’s define a function to add a state and associated observation to an existing model.

def addTimePoint(current: RandomVariable[(Static, List[Real])],
                     datum: Double) = for {
  tup <- current
  static = tup._1
  states = tup._2
  os = states.head
  ns <- Normal(((Real.one - static.a) * static.mu) + (static.a * os),
                 static.sig).param
  _ <- Normal(ns, static.sigD).fit(datum)
} yield (static, ns :: states)

Given this, we can generate the probabilistic program for our model as a fold over the data initialised with the prior.

val fullModel = obs.foldLeft(prior)(addTimePoint(_, _))
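Stripped of the RandomVariable plumbing, the fold threads the previous state through the data. At each time point it adds two log-density terms: one for the AR(1) transition to the new state, and one for the observation given that state. The plain-Scala sketch below computes that sum for fixed values and may help when reading addTimePoint. It uses no Rainier, omits the prior terms, and all names are hypothetical.

```scala
object DlmLogDensity {
  case class Static(mu: Double, a: Double, sig: Double, sigD: Double)

  // log of the N(x; m, s^2) density
  def logNormal(x: Double, m: Double, s: Double): Double =
    -0.5 * math.log(2 * math.Pi * s * s) - (x - m) * (x - m) / (2 * s * s)

  // Given static parameters, an initial state sp, one latent state per
  // observation and the observations themselves, accumulate the transition
  // and observation log densities with a left fold, carrying the previous state.
  def logLik(static: Static, sp: Double, states: Seq[Double],
             obs: Seq[Double]): Double = {
    val (ll, _) = states.zip(obs).foldLeft((0.0, sp)) {
      case ((acc, prev), (ns, y)) =>
        val mean = (1 - static.a) * static.mu + static.a * prev
        (acc + logNormal(ns, mean, static.sig) + logNormal(y, ns, static.sigD), ns)
    }
    ll
  }
}
```

In the Rainier version, the latent states are parameters rather than inputs, so HMC explores them jointly with the static parameters. That is why the full model has one parameter per time point plus the five statics.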

If we don’t want to keep samples for all of the variables, we can focus on the parameters of interest, wrapping the results in a Map for convenient sampling and plotting.

val model = for {
  tup <- fullModel
  static = tup._1
  states = tup._2
} yield
  Map("mu" -> static.mu,
  "a" -> static.a,
  "sig" -> static.sig,
  "sigD" -> static.sigD,
  "SP" -> states.reverse.head)

We can sample with

val out = model.sample(HMC(3), 100000, 10000 * 500, 500)

(this will take several minutes) and plot some diagnostics with

import com.cibo.evilplot.geometry.Extent
import com.stripe.rainier.plot.EvilTracePlot._

val truth = Map("mu" -> mu, "a" -> a, "sigD" -> sigD,
  "sig" -> sig, "SP" -> state(0))
render(traces(out, truth), "traceplots.png",
  Extent(1200, 1400))
render(pairs(out, truth), "pairs.png")

This generates the following diagnostic plots:

Everything looks good.

Summary

Rainier is a monadic embedded DSL for probabilistic programming in Scala. We can use standard functional combinators and for-expressions for building models to sample, and then run an efficient HMC algorithm on the resulting probability monad in order to obtain samples from the posterior distribution of the model.

See the Rainier repo for further details.

Monadic probabilistic programming in Scala with Rainier

Introduction

Rainier is an interesting new probabilistic programming library for Scala recently open-sourced by Stripe. Probabilistic programming languages provide a computational framework for building and fitting Bayesian models to data. There are many probabilistic programming languages, and there is currently a lot of interesting innovation happening with probabilistic programming languages embedded in strongly typed functional programming languages such as Scala and Haskell. However, most such languages tend to be developed by people lacking expertise in statistics and numerics, leading to elegant, composable languages which work well for toy problems, but don’t scale well to the kinds of practical problems that applied statisticians are interested in. Conversely, there are a few well-known probabilistic programming languages developed by and for statisticians which have efficient inference engines, but are hampered by inflexible, inelegant languages and APIs. Rainier is interesting because it is an attempt to bridge the gap between these two worlds: it has a functional, composable, extensible, monadic API, yet is backed by a very efficient, high-performance scalable inference engine, using HMC and a static compute graph for reverse-mode AD. Clearly there will be some loss of generality associated with choosing an efficient inference algorithm (e.g. for HMC, there needs to be a fixed number of parameters and they must all be continuous), but it still covers a large proportion of the class of hierarchical models commonly used in applied statistical modelling.

In this post I’ll give a quick introduction to Rainier using an interactive session requiring only that SBT is installed and the Rainier repo is downloaded or cloned.

Interactive session

To follow along with this post, just clone (or download and unpack) the Rainier repo, run SBT from the top-level Rainier directory, and paste in the commands. First start a Scala REPL.

project rainierPlot
console

Before we start building models, we need some data. For this post we will focus on a simple logistic regression model, and so we will begin by simulating some synthetic data consistent with such a model.

val r = new scala.util.Random(0)
val N = 1000
val beta0 = 0.1
val beta1 = 0.3
val x = (1 to N) map { i =>
  3.0 * r.nextGaussian
}
val theta = x map { xi =>
  beta0 + beta1 * xi
}
def expit(x: Double): Double = 1.0 / (1.0 + math.exp(-x))
val p = theta map expit
val y = p map (pi => (r.nextDouble < pi))

Now we have some synthetic data, we can fit the model and see if we are able to recover the “true” parameters used to generate the synthetic data. In Rainier, we build models by declaring probabilistic programs for the model and the data, and then run an inference engine to generate samples from the posterior distribution.

Start with a bunch of Rainier imports:

import com.stripe.rainier.compute._
import com.stripe.rainier.core._
import com.stripe.rainier.sampler._
import com.stripe.rainier.repl._

Now we want to build a model. We do so by describing the joint distribution of parameters and data. Rainier has a few built-in distributions, and these can be combined using standard functional monadic combinators such as map, zip, flatMap, etc., to create a probabilistic program representing a probability monad for the model. Due to the monadic nature of such probabilistic programs, it is often most natural to declare them using a for-expression.

val model = for {
  beta0 <- Normal(0, 5).param
  beta1 <- Normal(0, 5).param
  _ <- Predictor.from{x: Double =>
      {
        val theta = beta0 + beta1 * x
        val p = Real(1.0) / (Real(1.0) + (Real(0.0) - theta).exp)
        Categorical.boolean(p)
      }
    }.fit(x zip y)
} yield Map("b0"->beta0, "b1"->beta1)

This kind of construction is very natural for anyone familiar with monadic programming in Scala, but will no doubt be a little mysterious otherwise. RandomVariable is the probability monad used for HMC sampling, and these can be constructed from Distributions using .param (for unobserved parameters) and .fit (for variables with associated observations). Predictor is just a convenience for observations corresponding to covariate information. model is therefore a RandomVariable over beta0 and beta1, the two unobserved parameters of interest. Note that I briefly discussed this kind of pure functional approach to describing probabilistic programs (using Rand from Breeze) in my post on MCMC as a stream.
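Conceptually, the .fit on Categorical.boolean(p) contributes the Bernoulli log-likelihood, the sum over i of y_i θ_i - log(1 + exp(θ_i)) with θ_i = beta0 + beta1 x_i, to the log posterior that HMC explores. The following is a plain-Scala sketch of that quantity, with hypothetical names and no Rainier. It uses a numerically stable log(1 + exp(θ)) rather than the direct expit used in the model.

```scala
object LogisticLogLik {
  // Numerically stable log(1 + exp(z))
  def log1pExp(z: Double): Double =
    if (z > 0) z + math.log1p(math.exp(-z)) else math.log1p(math.exp(z))

  // Bernoulli log-likelihood with a logit link:
  //   log p(y | x) = y * theta - log(1 + exp(theta)),  theta = b0 + b1 * x
  def logLik(b0: Double, b1: Double, x: Seq[Double], y: Seq[Boolean]): Double =
    x.zip(y).map { case (xi, yi) =>
      val theta = b0 + b1 * xi
      (if (yi) theta else 0.0) - log1pExp(theta)
    }.sum
}
```

The stable form matters for large |θ|, where computing 1/(1 + exp(-θ)) and then taking logs can underflow to log(0).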

Now we have our probabilistic program, we can sample from it using HMC as follows.

implicit val rng = ScalaRNG(3)
val its = 10000
val thin = 5
val out = model.sample(HMC(5), 10000, its*thin, thin)
println(out.take(10))

The argument to HMC() is the number of leapfrog steps to take per iteration.
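Each HMC iteration simulates Hamiltonian dynamics using the leapfrog integrator, and the number of leapfrog steps controls how far each proposal travels. The following is a minimal one-dimensional sketch, not Rainier's implementation. It omits the momentum refresh, the Metropolis accept/reject step and any step-size adaptation. Leapfrog is a hypothetical name.

```scala
object Leapfrog {
  // l leapfrog steps of size eps for position q and momentum p, given the
  // gradient of the log target density. Returns the final (q, p).
  def integrate(q0: Double, p0: Double, eps: Double, l: Int,
                gradLogPi: Double => Double): (Double, Double) = {
    var q = q0
    var p = p0 + 0.5 * eps * gradLogPi(q)   // initial half step for momentum
    for (i <- 1 to l) {
      q += eps * p                          // full step for position
      if (i < l) p += eps * gradLogPi(q)    // full momentum step, except at the end
    }
    p += 0.5 * eps * gradLogPi(q)           // final half step for momentum
    (q, p)
  }
}
```

For a standard normal target (gradLogPi(q) = -q), the energy q²/2 + p²/2 is nearly conserved along the trajectory. Running the integrator again from the end point with the momentum negated returns to the start. This reversibility is what makes the HMC acceptance ratio simple.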

Finally, we can use EvilPlot to look at the HMC output and check that we have managed to reasonably recover the true parameters associated with our synthetic data.

import com.cibo.evilplot.geometry.Extent
import com.stripe.rainier.plot.EvilTracePlot._

render(traces(out, truth = Map("b0" -> beta0, "b1" -> beta1)),
  "traceplots.png", Extent(1200, 1000))
render(pairs(out, truth = Map("b0" -> beta0, "b1" -> beta1)), "pairs.png")

Everything looks good, and the sampling is very fast!

Further reading

For further information, see the Rainier repo. In particular, start with the tour of Rainier’s core, which gives a more detailed introduction to how Rainier works than this post. Those interested in how the efficient AD works may want to read about the compute graph, and the implementation notes explain how it all fits together. There is some basic ScalaDoc for the core package, and also some examples (including this one), and there’s a gitter channel for asking questions. This is a very new project, so there are a few minor bugs and wrinkles in the initial release, but development is progressing rapidly, so I fully expect the library to get properly battle-hardened over the next few months.

For those unfamiliar with the monadic approach to probabilistic programming, Ścibior et al. (2015) is probably a good starting point.

Comonads for scientific and statistical computing in Scala

Introduction

In a previous post I gave a brief introduction to monads in Scala, aimed at people interested in scientific and statistical computing. Monads are a concept from category theory which turn out to be exceptionally useful for solving many problems in functional programming. But most categorical concepts have a dual, usually prefixed with “co”, so the dual of a monad is the comonad. Comonads turn out to be especially useful for formulating algorithms from scientific and statistical computing in an elegant way. In this post I’ll illustrate their use in signal processing, image processing, numerical integration of PDEs, and Gibbs sampling (of an Ising model). Comonads enable the extension of a local computation to a global computation, and this pattern crops up all over the place in statistical computing.

Monads and comonads

Simplifying massively, from the viewpoint of a Scala programmer, a monad is a mappable (functor) type class augmented with the methods pure and flatMap:

trait Monad[M[_]] extends Functor[M] {
  def pure[T](v: T): M[T]
  def flatMap[T,S](v: M[T])(f: T => M[S]): M[S]
}

In category theory, the dual of a concept is typically obtained by “reversing the arrows”. Here that means reversing the direction of the methods pure and flatMap to get extract and coflatMap, respectively.

trait Comonad[W[_]] extends Functor[W] {
  def extract[T](v: W[T]): T
  def coflatMap[T,S](v: W[T])(f: W[T] => S): W[S]
}

So, while pure allows you to wrap plain values in a monad, extract allows you to get a value out of a comonad. So you can always get a value out of a comonad (unlike a monad). Similarly, while flatMap allows you to transform a monad using a function returning a monad, coflatMap allows you to transform a comonad using a function which collapses a comonad to a single value. It is coflatMap (sometimes called extend) which can extend a local computation (producing a single value) to the entire comonad. We’ll look at how that works in the context of some familiar examples.

Applying a linear filter to a data stream

One of the simplest examples of a comonad is an infinite stream of data. I’ve discussed streams in a previous post. By focusing on infinite streams we know the stream will never be empty, so there will always be a value that we can extract. Which value does extract give? For a Stream encoded as some kind of lazy list, the only value we actually know is the value at the head of the stream, with subsequent values to be lazily computed as required. So the head of the list is the only reasonable value for extract to return.

Understanding coflatMap is a bit more tricky, but it is coflatMap that provides us with the power to apply a non-trivial statistical computation to the stream. The input is a function which transforms a stream into a value. In our example, that will be a function which computes a weighted average of the first few values and returns that weighted average as the result. But the return type of coflatMap must be a stream of such computations. Following the types, a few minutes’ thought reveals that the only reasonable thing to do is to return the stream formed by applying the weighted average function to all sub-streams, recursively. So, for a Stream s (of type Stream[T]) and an input function f: Stream[T] => S, we form a stream whose head is f(s) and whose tail is coflatMap(f) applied to s.tail. Again, since we are working with an infinite stream, we don’t have to worry about whether or not the tail is empty. This gives us our comonadic Stream, and it is exactly what we need for applying a linear filter to the data stream.
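The following sketch makes the construction concrete without Cats or laziness. It implements the same extract/coflatMap pair for a finite non-empty list, using a hypothetical Nel type that is not Cats' NonEmptyList. coflatMap applies f to the whole structure and then recursively to the tail. The only difference from the infinite-stream case is that the recursion stops when the tail is empty.

```scala
object NelComonad {
  // A minimal non-empty list: it always has a head, so extract always succeeds.
  case class Nel[T](head: T, tail: List[T]) {
    def toList: List[T] = head :: tail

    // extract returns the value in focus, the head
    def extract: T = head

    // coflatMap applies f to the whole list, then recursively to the tail,
    // giving f applied to every non-empty suffix (fine for short lists)
    def coflatMap[S](f: Nel[T] => S): Nel[S] = tail match {
      case Nil    => Nel(f(this), Nil)
      case h :: t => Nel(f(this), Nel(h, t).coflatMap(f).toList)
    }
  }
}
```

For example, `Nel(1, List(2, 3)).coflatMap(_.toList.sum)` gives the suffix sums 6, 5 and 3. Using `_.extract` as the function gives back the original list, which is one of the comonad laws. A weighted moving average is the same pattern with a function that zips the weights against the first few values. Near the end of a finite list, the windows are simply truncated.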

In Scala, Cats is a library providing type classes from category theory, and instances of those type classes for parametrised types in the standard library. In particular, it provides us with comonadic functionality for the standard Scala Stream. Let’s start by defining a stream corresponding to the logistic map.

import cats._
import cats.implicits._

val lam = 3.7
def s = Stream.iterate(0.5)(x => lam*x*(1-x))
s.take(10).toList
// res0: List[Double] = List(0.5, 0.925, 0.25668749999999985,
//  0.7059564011718747, 0.7680532550204203, 0.6591455741499428, ...

Let us now suppose that we want to apply a linear filter to this stream, in order to smooth the values. The idea behind using comonads is that you figure out how to generate one desired value, and let coflatMap take care of applying the same logic to the rest of the structure. So here, we need a function to generate the first filtered value (since extract is focused on the head of the stream). A simple first attempt at a function to do this might look like the following.

  def linearFilterS(weights: Stream[Double])(s: Stream[Double]): Double =
    (weights, s).parMapN(_*_).sum

This aligns each weight in parallel with a corresponding value from the stream, and combines them using multiplication. The resulting (hopefully finite length) stream is then summed (with addition). We can test this with

linearFilterS(Stream(0.25,0.5,0.25))(s)
// res1: Double = 0.651671875

and let coflatMap extend this computation to the rest of the stream with something like:

s.coflatMap(linearFilterS(Stream(0.25,0.5,0.25))).take(5).toList
// res2: List[Double] = List(0.651671875, 0.5360828502929686, ...
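For readers without Cats to hand, the same weighted sum can be sketched with the standard library alone. In the sketch below, `zip` plays the role of the parallel alignment done by `parMapN`, and because it stops at the shorter input, a finite list of weights keeps the sum finite even though the stream is infinite. It assumes Scala 2.13+ (`LazyList`), and the name `linearFilterZ` is hypothetical.

```scala
object ZipFilter {
  // Hypothetical helper: align weights with the leading values of the stream,
  // multiply pairwise and sum. zip truncates to the (finite) weight list.
  def linearFilterZ(weights: Seq[Double])(s: LazyList[Double]): Double =
    weights.zip(s).map { case (w, x) => w * x }.sum

  val lam = 3.7
  val logistic: LazyList[Double] = LazyList.iterate(0.5)(x => lam * x * (1 - x))

  // First filtered value of the logistic-map stream.
  val first: Double = linearFilterZ(List(0.25, 0.5, 0.25))(logistic)
}
```

Up to floating-point rounding, `ZipFilter.first` agrees with the value 0.651671875 reported as `res1` for `linearFilterS`.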

This is all completely fine, but our linearFilterS function is specific to the Stream comonad, despite the fact that all we’ve used about it in the function is that it is parallelly composable and foldable. We can make this much more generic as follows:

  def linearFilter[F[_]: Foldable, G[_]](
    weights: F[Double], s: F[Double]
  )(implicit ev: NonEmptyParallel[F, G]): Double =
    (weights, s).parMapN(_*_).fold

This uses some fairly advanced Scala concepts which I don’t want to get into right now (I should also acknowledge that I had trouble getting the syntax right for this, and got help from Fabio Labella (@SystemFw) on the Cats gitter channel). But this version is more generic, and can be used to linearly filter data structures other than Stream. We can use this for regular Streams as follows:

s.coflatMap(s => linearFilter(Stream(0.25,0.5,0.25),s))
// res3: scala.collection.immutable.Stream[Double] = Stream(0.651671875, ?)

But we can apply this new filter to other collections. This could be other, more sophisticated, streams such as provided by FS2, Monix or Akka streams. But it could also be a non-stream collection, such as List:

val sl = s.take(10).toList
sl.coflatMap(sl => linearFilter(List(0.25,0.5,0.25),sl))
// res4: List[Double] = List(0.651671875, 0.5360828502929686, ...
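One consequence of moving to a finite List is worth spelling out. As far as I understand the Cats instance, coflatMap on a List applies the function to each non-empty suffix of the list, so the last couple of smoothed values are computed from truncated windows whose weights no longer sum to one. The sketch below mimics that behaviour with the standard library's `tails`; `coflatMapList` and `filter3` are hypothetical names.

```scala
object ListCoflatMap {
  // Hypothetical stand-in for List coflatMap: apply f to every non-empty suffix, in order.
  def coflatMapList[A, B](xs: List[A])(f: List[A] => B): List[B] =
    xs.tails.filter(_.nonEmpty).map(f).toList

  // Hypothetical three-point filter; zip silently drops missing neighbours at the end.
  def filter3(w: List[Double]): Double =
    List(0.25, 0.5, 0.25).zip(w).map { case (a, b) => a * b }.sum

  val xs: List[Double] = List(1.0, 2.0, 3.0, 4.0)
  val smoothed: List[Double] = coflatMapList(xs)(filter3)
}
```

Here `smoothed` is `List(2.0, 3.0, 2.75, 1.0)`: the final two values dip because their windows are incomplete. If that edge effect matters, one option is to divide by the sum of the weights actually used.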

Assuming that we have the Breeze scientific library available, we can plot the raw and smoothed trajectories.

def myFilter(s: Stream[Double]): Double =
  linearFilter(Stream(0.25, 0.5, 0.25),s)
val n = 500
import breeze.plot._
import breeze.linalg._
val fig = Figure(s"The (smoothed) logistic map (lambda=$lam)")
val p0 = fig.subplot(3,1,0)
p0 += plot(linspace(1,n,n),s.take(n))
p0.ylim = (0.0,1.0)
p0.title = s"The logistic map (lambda=$lam)"
val p1 = fig.subplot(3,1,1)
p1 += plot(linspace(1,n,n),s.coflatMap(myFilter).take(n))
p1.ylim = (0.0,1.0)
p1.title = "Smoothed by a simple linear filter"
val p2 = fig.subplot(3,1,2)
p2 += plot(linspace(1,n,n),
  (1 to 5).foldLeft(s)((acc, _) => acc.coflatMap(myFilter)).take(n))
p2.ylim = (0.0,1.0)
p2.title = "Smoothed with 5 applications of the linear filter"
fig.refresh

Image processing and the heat equation

Streaming data is in no way the only context in which comonads facilitate an elegant approach to scientific and statistical computing. Comonads crop up wherever we want to extend a computation that is local to a small part of a data structure to the full data structure. Another commonly cited area of application of comonadic approaches is image processing (I should acknowledge that this section of the post is very much influenced by a blog post on comonadic image processing in Haskell). However, the kinds of operations used in image processing are in many cases very similar to the operations used in finite difference approaches to numerical integration of partial differential equations (PDEs) such as the heat equation, so in this section I will blur (sic) the distinction between the two, and numerically integrate the 2D heat equation in order to Gaussian blur a noisy image.

First we need a simple image type which can have pixels of arbitrary type T (this is very important – all functors must be fully type polymorphic).

  import scala.collection.parallel.immutable.ParVector
  case class Image[T](w: Int, h: Int, data: ParVector[T]) {
    def apply(x: Int, y: Int): T = data(x*h+y)
    def map[S](f: T => S): Image[S] = Image(w, h, data map f)
    def updated(x: Int, y: Int, value: T): Image[T] =
      Image(w,h,data.updated(x*h+y,value))
  }

Here I’ve chosen to back the image with a parallel immutable vector. This wasn’t necessary, but since this type has a map operation which automatically parallelises over multiple cores, any map operations applied to the image will be automatically parallelised. This will ultimately lead to all of our statistical computations being automatically parallelised without us having to think about it.

As it stands, this image isn’t comonadic, since it doesn’t implement extract or coflatMap. Unlike the case of Stream, there isn’t really a uniquely privileged pixel, so it’s not clear what extract should return. For many data structures of this type, we make them comonadic by adding a “cursor” pointing to a “current” element of interest, and use this as the focus for computations applied with coflatMap. This is simplest to explain by example. We can define our “pointed” image type as follows:

  case class PImage[T](x: Int, y: Int, image: Image[T]) {
    def extract: T = image(x, y)
    def map[S](f: T => S): PImage[S] = PImage(x, y, image map f)
    def coflatMap[S](f: PImage[T] => S): PImage[S] = PImage(
      x, y, Image(image.w, image.h,
      (0 until (image.w * image.h)).toVector.par.map(i => {
        val xx = i / image.h
        val yy = i % image.h
        f(PImage(xx, yy, image))
      })))

The closing brace is deliberately missing, as I’m not quite finished. Here x and y represent the location of our cursor, so extract returns the value of the pixel indexed by our cursor. Similarly, coflatMap forms an image where the value of the image at each location is the result of applying the function f to the image which had the cursor set to that location. Clearly f should use the cursor in some way; otherwise the image will have the same value at every pixel location. Note that map and coflatMap operations will be automatically parallelised. The intuitive idea behind coflatMap is that it extends local computations. For the stream example, the local computation was a linear combination of nearby values. Similarly, in image analysis problems, we often want to apply a linear filter to nearby pixels. We can get at the pixel at the cursor location using extract, but we probably also want to be able to move the cursor around to nearby locations. We can do that by adding some appropriate methods to complete the class definition.

    def up: PImage[T] = {
      val py = y-1
      val ny = if (py >= 0) py else (py + image.h)
      PImage(x,ny,image)
    }
    def down: PImage[T] = {
      val py = y+1
      val ny = if (py < image.h) py else (py - image.h)
      PImage(x,ny,image)
    }
    def left: PImage[T] = {
      val px = x-1
      val nx = if (px >= 0) px else (px + image.w)
      PImage(nx,y,image)
    }
    def right: PImage[T] = {
      val px = x+1
      val nx = if (px < image.w) px else (px - image.w)
      PImage(nx,y,image)
    }
  }

Here each method returns a new pointed image with the cursor shifted by one pixel in the appropriate direction. I’ve used periodic boundary conditions here, which often make sense for numerical integration of PDEs, but make less sense for real image analysis problems. Note also that we have embedded all “indexing” issues inside the definitions of our classes. Now that we have them, none of the statistical algorithms that we develop will involve any explicit indexing, which makes bugs such as “off-by-one” or flipped-axis errors much less likely.
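To experiment without Breeze or parallel collections, a cut-down, sequential version of the pointed image can be sketched as below. `Grid` is a hypothetical name, the backing store is a plain `Vector` rather than a `ParVector`, and the periodic wrap-around is written with `Math.floorMod` instead of the explicit `if` tests; the indexing convention `x*h + y` matches `Image`.

```scala
// Hypothetical sequential analogue of PImage: cursor (x, y) on a w x h periodic grid.
final case class Grid[T](x: Int, y: Int, w: Int, h: Int, data: Vector[T]) {
  def extract: T = data(x * h + y)

  // New grid whose value at each location is f applied with the cursor moved there.
  def coflatMap[S](f: Grid[T] => S): Grid[S] =
    Grid(x, y, w, h, Vector.tabulate(w * h)(i => f(Grid(i / h, i % h, w, h, data))))

  // Periodic boundaries: floorMod maps -1 to n - 1 and n to 0.
  def up: Grid[T]    = copy(y = Math.floorMod(y - 1, h))
  def down: Grid[T]  = copy(y = Math.floorMod(y + 1, h))
  def left: Grid[T]  = copy(x = Math.floorMod(x - 1, w))
  def right: Grid[T] = copy(x = Math.floorMod(x + 1, w))
}

object GridDemo {
  // A 3 x 2 grid (w = 3 columns, h = 2 rows) holding the values 0 to 5.
  val g: Grid[Int] = Grid(0, 0, 3, 2, Vector.tabulate(6)(i => i))
  // The value of each pixel's left neighbour, wrapping around at the edge.
  val leftVals: Grid[Int] = g.coflatMap(_.left.extract)
}
```

Here `GridDemo.leftVals.data` is `Vector(4, 5, 0, 1, 2, 3)`, and `g.coflatMap(_.extract) == g`, which is one of the comonad laws.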

This class is now fine for our requirements. But if we wanted Cats to understand that this structure is really a comonad (perhaps because we wanted to use derived methods, such as coflatten), we would need to provide evidence for this. The details aren’t especially important for this post, but we can do it simply as follows:

  implicit val pimageComonad: Comonad[PImage] = new Comonad[PImage] {
    def extract[A](wa: PImage[A]) = wa.extract
    def coflatMap[A,B](wa: PImage[A])(f: PImage[A] => B): PImage[B] =
      wa.coflatMap(f)
    def map[A,B](wa: PImage[A])(f: A => B): PImage[B] = wa.map(f)
  }

It’s handy to have some functions for converting Breeze dense matrices back and forth with our image class.

  import breeze.linalg.{Vector => BVec, _}
  def BDM2I[T](m: DenseMatrix[T]): Image[T] =
    Image(m.cols, m.rows, m.data.toVector.par)
  def I2BDM(im: Image[Double]): DenseMatrix[Double] =
    new DenseMatrix(im.h,im.w,im.data.toArray)

Now we are ready to see how to use this in practice. Let’s start by defining a very simple linear filter.

def fil(pi: PImage[Double]): Double = (2*pi.extract+
  pi.up.extract+pi.down.extract+pi.left.extract+pi.right.extract)/6.0

This simple filter can be used to “smooth” or “blur” an image. However, from a more sophisticated viewpoint, exactly this type of filter can be used to represent one time step of a numerical method for time integration of the 2D heat equation. Now we can simulate a noisy image and apply our filter to it using coflatMap:

import breeze.stats.distributions.Gaussian
val bdm = DenseMatrix.tabulate(200,250){case (i,j) => math.cos(
  0.1*math.sqrt((i*i+j*j))) + Gaussian(0.0,2.0).draw}
val pim0 = PImage(0,0,BDM2I(bdm))
def pims = Stream.iterate(pim0)(_.coflatMap(fil))

Note that here, rather than just applying the filter once, I’ve generated an infinite stream of pointed images, each one representing an additional application of the linear filter. Thus the sequence represents the time solution of the heat equation with initial condition corresponding to our simulated noisy image.
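To make the heat-equation connection concrete (this is the standard textbook scheme, not something derived in the post), write the explicit finite-difference update for $\partial u/\partial t = \alpha \nabla^2 u$, using a forward Euler step $\Delta t$ in time, the five-point Laplacian on a grid of spacing $\Delta x$, and $r = \alpha \Delta t / \Delta x^2$:

```latex
u^{n+1}_{i,j} = u^{n}_{i,j} + r\left(u^{n}_{i+1,j} + u^{n}_{i-1,j} + u^{n}_{i,j+1} + u^{n}_{i,j-1} - 4u^{n}_{i,j}\right)
              = (1 - 4r)\,u^{n}_{i,j} + r\left(u^{n}_{i+1,j} + u^{n}_{i-1,j} + u^{n}_{i,j+1} + u^{n}_{i,j-1}\right)

\text{With } r = \tfrac{1}{6}: \qquad
u^{n+1}_{i,j} = \frac{2u^{n}_{i,j} + u^{n}_{i+1,j} + u^{n}_{i-1,j} + u^{n}_{i,j+1} + u^{n}_{i,j-1}}{6}
```

since $1 - 4/6 = 2/6$. This is exactly `fil`, with `pi.extract` playing the role of $u^{n}_{i,j}$ and the four shifted cursors supplying the neighbours. The explicit scheme is stable for $r \le 1/4$, so $r = 1/6$ sits inside the stable range.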

We can render the first few frames to check that it seems to be working.

import breeze.plot._
val fig = Figure("Diffusing a noisy image")
pims.take(25).zipWithIndex.foreach{case (pim,i) => {
  val p = fig.subplot(5,5,i)
  p += image(I2BDM(pim.image))
}}

Note that the numerical integration is carried out in parallel on all available cores automatically. Other image filters can be applied, and other (parabolic) PDEs can be numerically integrated in an essentially similar way.
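A quick sanity check on the heat-equation reading of `fil`: with periodic boundaries the filter conserves the total "heat" (the sum over all pixels), since each pixel contributes 2/6 of its value to itself and 1/6 to each of its four neighbours, and it should reduce the spread of a noisy field. The sketch below checks both properties on a small random grid using plain nested arrays and `Math.floorMod` for the wrap-around, instead of `PImage` and Breeze; `filStep`, `total` and `variance` are hypothetical helper names.

```scala
import scala.util.Random

object HeatCheck {
  // Hypothetical helper: one (2*centre + 4 neighbours)/6 step on a periodic w x h grid.
  def filStep(u: Array[Array[Double]]): Array[Array[Double]] = {
    val w = u.length
    val h = u(0).length
    Array.tabulate(w, h) { (x, y) =>
      (2 * u(x)(y) +
        u(x)(Math.floorMod(y - 1, h)) + u(x)(Math.floorMod(y + 1, h)) +
        u(Math.floorMod(x - 1, w))(y) + u(Math.floorMod(x + 1, w))(y)) / 6.0
    }
  }

  def total(u: Array[Array[Double]]): Double = u.map(_.sum).sum

  def variance(u: Array[Array[Double]]): Double = {
    val xs = u.flatten
    val m = xs.sum / xs.length
    xs.map(v => (v - m) * (v - m)).sum / xs.length
  }

  val rng = new Random(42)
  val u0: Array[Array[Double]] = Array.fill(20, 25)(rng.nextGaussian())
  // Ten applications of the filter, i.e. ten explicit time steps.
  val u10: Array[Array[Double]] = Iterator.iterate(u0)(filStep).drop(10).next()
}
```

The total is preserved up to rounding error, while the variance about the (unchanged) mean shrinks, which is the smoothing seen in the rendered frames.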

Gibbs sampling the Ising model

Another place where the concept of extending a local computation to a global computation crops up is in the context of Gibbs sampling a high-dimensional probability distribution by cycling through the sampling of each variable in turn from its full-conditional distribution. I’ll illustrate this here using the Ising model, so that I can reuse the pointed image class from above, but the principles apply to any Gibbs sampling problem. In particular, the Ising model that we consider has a conditional independence structure corresponding to a graph of a square lattice. As above, we will use the comonadic structure of the square lattice to construct a Gibbs sampler. However, we can construct a Gibbs sampler for arbitrary graphical models in an essentially identical way by using a graph comonad.

Let’s begin by simulating a random image containing +/-1s:

import breeze.stats.distributions.{Binomial,Bernoulli}
val beta = 0.4
val bdm = DenseMatrix.tabulate(500,600){
  case (i,j) => (new Binomial(1,0.2)).draw
}.map(_*2 - 1) // random matrix of +/-1s
val pim0 = PImage(0,0,BDM2I(bdm))

We can use this to initialise our Gibbs sampler. We now need a Gibbs kernel representing the update of each pixel.

def gibbsKernel(pi: PImage[Int]): Int = {
   val sum = pi.up.extract+pi.down.extract+pi.left.extract+pi.right.extract
   val p1 = math.exp(beta*sum)
   val p2 = math.exp(-beta*sum)
   val probplus = p1/(p1+p2)
   if (new Bernoulli(probplus).draw) 1 else -1
}
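For reference, `gibbsKernel` draws each spin from its full-conditional distribution. Assuming the usual zero-field Ising model on the four-neighbour lattice (periodic, as in `PImage`) with coupling parameter $\beta$ (`beta` in the code), the terms not involving spin $i$ cancel, leaving:

```latex
\pi(s) \propto \exp\Big(\beta \sum_{i \sim j} s_i s_j\Big), \qquad s_i \in \{-1, +1\}

P(s_i = +1 \mid s_{-i})
  = \frac{e^{\beta S_i}}{e^{\beta S_i} + e^{-\beta S_i}}
  = \frac{1}{1 + e^{-2\beta S_i}},
\qquad S_i = \sum_{j \sim i} s_j
```

Here $S_i$ is `sum` in the code and the middle expression is `probplus`; the right-hand form is algebraically identical and needs only one call to `math.exp`.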

So far so good, but there are a couple of issues that we need to consider before we plough ahead and start coflatMapping. The first is that pure functional programmers will object to the fact that this function is not pure. It is a stochastic function which has the side-effect of mutating the random number state. I’m just going to duck that issue here, as I’ve previously discussed how to fix it using probability monads, and I don’t want it to distract us here.

However, there is a more fundamental problem here relating to parallel versus sequential application of Gibbs kernels. coflatMap is conceptually parallel (irrespective of how it is implemented) in that all computations used to build the new comonad are based solely on the information available in the starting comonad. On the other hand, detailed balance of the Markov chain will only be preserved if the kernels for each pixel are applied sequentially. So if we coflatMap this kernel over the image we will break detailed balance. I should emphasise that this has nothing to do with the fact that I’ve implemented the pointed image using a parallel vector. Exactly the same issue would arise if we switched to backing the image with a regular (sequential) immutable Vector.

The trick here is to recognise that if we coloured alternate pixels black and white using a chequerboard pattern, then all of the black pixels are conditionally independent given the white pixels and vice-versa. Conditionally independent pixels can be updated by parallel application of a Gibbs kernel. So we just need separate kernels for updating odd and even pixels.

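// Update only pixels with odd x+y, leaving even pixels unchanged
def oddKernel(pi: PImage[Int]): Int =
  if ((pi.x+pi.y) % 2 == 0) pi.extract else gibbsKernel(pi)
// Update only pixels with even x+y, leaving odd pixels unchanged
def evenKernel(pi: PImage[Int]): Int =
  if ((pi.x+pi.y) % 2 != 0) pi.extract else gibbsKernel(pi)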
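One subtlety is worth checking with the periodic boundaries used by `PImage`: a pixel on the left edge has its left neighbour on the right edge, so the chequerboard colouring only gives every neighbour the opposite colour when both the width and the height are even (true for the 500 x 600 image used here). The hypothetical helper `validColouring` below checks this directly using the same wrap-around rule.

```scala
object ChequerboardCheck {
  // Hypothetical helper: true if every 4-neighbour (with periodic wrap-around)
  // has the opposite parity of x + y, i.e. the two colour classes never touch.
  def validColouring(w: Int, h: Int): Boolean = {
    def parity(x: Int, y: Int): Int = (x + y) % 2
    val checks = for {
      x <- 0 until w
      y <- 0 until h
      n <- List(
        (Math.floorMod(x - 1, w), y), (Math.floorMod(x + 1, w), y),
        (x, Math.floorMod(y - 1, h)), (x, Math.floorMod(y + 1, h)))
    } yield parity(x, y) != parity(n._1, n._2)
    checks.forall(ok => ok)
  }
}
```

For example, `validColouring(4, 4)` is true but `validColouring(5, 4)` is false, because the wrap-around neighbour of column 0 is column 4, which has the same parity.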

Each of these kernels can be coflatMapped over the image preserving detailed balance of the chain. So we can now construct an infinite stream of MCMC iterations as follows.

def pims = Stream.iterate(pim0)(_.coflatMap(oddKernel).
  coflatMap(evenKernel))

We can animate the first few iterations with:

import breeze.plot._
val fig = Figure("Ising model Gibbs sampler")
fig.width = 1000
fig.height = 800
pims.take(50).zipWithIndex.foreach{case (pim,i) => {
  print(s"$i ")
  fig.clear
  val p = fig.subplot(1,1,0)
  p.title = s"Ising model: frame $i"
  p += image(I2BDM(pim.image.map{_.toDouble}))
  fig.refresh
}}
println

Here I have a movie showing the first 1000 iterations. Note that YouTube seems to have over-compressed it, but you should get the basic idea.
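For experimentation without Breeze or parallel collections, the same chequerboard scheme can be sketched sequentially with plain arrays and `scala.util.Random`. The names `drawSpin`, `halfSweep` and `sweep` are hypothetical. Each half-sweep reads only the previous configuration, mirroring the `coflatMap` semantics, and is valid because pixels of one colour are conditionally independent given the other colour (assuming even dimensions). The spin probability is written as 1/(1 + exp(-2 * beta * sum)), which is algebraically the same as `probplus`.

```scala
import scala.util.Random

object IsingSketch {
  // Hypothetical full-conditional draw for one spin given its neighbour sum.
  def drawSpin(sum: Int, beta: Double, rng: Random): Int =
    if (rng.nextDouble() < 1.0 / (1.0 + math.exp(-2.0 * beta * sum))) 1 else -1

  // Update only pixels with (x + y) % 2 == colour, reading the old grid throughout.
  def halfSweep(s: Array[Array[Int]], colour: Int, beta: Double, rng: Random): Array[Array[Int]] = {
    val w = s.length
    val h = s(0).length
    Array.tabulate(w, h) { (x, y) =>
      if ((x + y) % 2 != colour) s(x)(y)
      else {
        val sum = s(Math.floorMod(x - 1, w))(y) + s(Math.floorMod(x + 1, w))(y) +
          s(x)(Math.floorMod(y - 1, h)) + s(x)(Math.floorMod(y + 1, h))
        drawSpin(sum, beta, rng)
      }
    }
  }

  // One MCMC iteration: update the odd colour class, then the even one.
  def sweep(beta: Double, rng: Random)(s: Array[Array[Int]]): Array[Array[Int]] =
    halfSweep(halfSweep(s, 1, beta, rng), 0, beta, rng)

  val rng = new Random(1)
  // Random 20 x 30 initial image of +/-1s, roughly 20% +1s as in the text.
  val init: Array[Array[Int]] = Array.fill(20, 30)(if (rng.nextDouble() < 0.2) 1 else -1)
  val after50: Array[Array[Int]] = Iterator.iterate(init)(sweep(0.4, rng)).drop(50).next()
}
```

Because each half-sweep only touches one colour class, the inner `Array.tabulate` could be swapped for a parallel map without changing the target distribution, which is the property the comonadic version exploits.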

Again, note that this MCMC sampler runs in parallel on all available cores, automatically. The odd/even pixel updating highlights another point that crops up a lot in functional programming: very often, thinking about how to express an algorithm functionally leads to an algorithm which parallelises naturally. For general graphs, figuring out which groups of nodes can be updated in parallel is essentially the graph colouring problem. I’ve discussed this previously in relation to parallel MCMC in:

Wilkinson, D. J. (2005) Parallel Bayesian Computation, Chapter 16 in E. J. Kontoghiorghes (ed.) Handbook of Parallel Computing and Statistics, Marcel Dekker/CRC Press, 481-512.
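For a general graph, a simple (not necessarily optimal) way to find such groups is greedy colouring: visit the nodes in some order and give each the smallest colour not already used by a neighbour. Nodes sharing a colour are then non-adjacent, so in a graphical model whose full conditionals depend only on neighbours, each colour class can be updated in parallel given the rest. The sketch below, with a hypothetical `greedyColouring` over an adjacency map, is a generic illustration rather than code from the post.

```scala
object GreedyColouring {
  // Hypothetical helper: assign each node the smallest colour (0, 1, 2, ...)
  // not used by an already-coloured neighbour, visiting nodes in sorted order.
  def greedyColouring(adj: Map[Int, Set[Int]]): Map[Int, Int] =
    adj.keys.toList.sorted.foldLeft(Map.empty[Int, Int]) { (colours, node) =>
      val used = adj(node).flatMap(n => colours.get(n).toList)
      val c = Iterator.from(0).find(k => !used.contains(k)).get
      colours + (node -> c)
    }

  // A 4-cycle 0-1-2-3-0, like a tiny periodic lattice: two colours suffice.
  val cycle4: Map[Int, Set[Int]] =
    Map(0 -> Set(1, 3), 1 -> Set(0, 2), 2 -> Set(1, 3), 3 -> Set(2, 0))
  val colouring: Map[Int, Int] = greedyColouring(cycle4)
}
```

On the 4-cycle this gives `Map(0 -> 0, 1 -> 1, 2 -> 0, 3 -> 1)`, the graph analogue of the chequerboard.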

Further reading

There are quite a few blog posts discussing comonads in the context of Haskell. In particular, the post on comonads for image analysis I mentioned previously, and this one on cellular automata. Bartosz’s post on comonads gives some connection back to the mathematical origins. Runar’s Scala comonad tutorial is the best source I know for comonads in Scala.

Full runnable code corresponding to this blog post is available from my blog repo.

Statistical computing with Scala free on-line course

I’ve recently delivered a three-day intensive short-course on Scala for statistical computing and data science. The course seemed to go well, and the experience has convinced me that Scala should be used a lot more by statisticians and data scientists for a range of problems in statistical computing. In particular, the simplicity of writing fast, efficient parallel algorithms is reason alone to take a careful look at Scala. With a view to helping more statisticians get to grips with Scala, I’ve decided to freely release all of the essential materials associated with the course: the course notes (as PDF), code fragments, complete examples, end-of-chapter exercises, etc. Although I developed the materials with the training course in mind, the course notes are reasonably self-contained, making the course quite suitable for self-study. At some point I will probably flesh out the notes into a proper book, but that will take me a little while.

I’ve written a brief self-study guide to point people in the right direction. For people studying the material in their spare time, the course is probably best done over nine weeks (one chapter per week), and this will then cover material at a similar rate to a typical MOOC.

The nine chapters are:

1. Introduction
2. Scala and FP Basics
3. Collections
4. Scala Breeze
5. Monte Carlo
6. Statistical modelling
7. Tools
8. Apache Spark
9. Advanced topics

For anyone frustrated by the limitations of dynamic languages such as R, Python or Octave, this course should provide a good pathway to an altogether more sophisticated, modern programming paradigm.

MCMC as a Stream

Introduction

This weekend I’ve been preparing some material for my upcoming Scala for statistical computing short course. As part of the course, I thought it would be useful to walk through how to think about and structure MCMC codes, and in particular, how to think about MCMC algorithms as infinite streams of state. This material is reasonably stand-alone, so it seems suitable for a blog post. Complete runnable code for the examples in this post is available from my blog repo.

A simple MH sampler

For this post I will just consider a trivial toy Metropolis algorithm using a Uniform random walk proposal to target a standard normal distribution. I’ve considered this problem before on my blog, so if you aren’t very familiar with Metropolis-Hastings algorithms, you might want to quickly review my post on Metropolis-Hastings MCMC algorithms in R before continuing. At the end of that post, I gave the following R code for the Metropolis sampler:

metrop3<-function(n=1000,eps=0.5) 
{
        vec=vector("numeric", n)
        x=0
        oldll=dnorm(x,log=TRUE)
        vec[1]=x
        for (i in 2:n) {
                can=x+runif(1,-eps,eps)
                loglik=dnorm(can,log=TRUE)
                loga=loglik-oldll
                if (log(runif(1)) < loga) { 
                        x=can
                        oldll=loglik
                        }
                vec[i]=x
        }
        vec
}

I will begin this post with a fairly direct translation of this algorithm into Scala:

import breeze.linalg.DenseVector
import breeze.stats.distributions.{Gaussian, Uniform}

def metrop1(n: Int = 1000, eps: Double = 0.5): DenseVector[Double] = {
    val vec = DenseVector.fill(n)(0.0)
    var x = 0.0
    var oldll = Gaussian(0.0, 1.0).logPdf(x)
    vec(0) = x
    (1 until n).foreach { i =>
      val can = x + Uniform(-eps, eps).draw
      val loglik = Gaussian(0.0, 1.0).logPdf(can)
      val loga = loglik - oldll
      if (math.log(Uniform(0.0, 1.0).draw) < loga) {
        x = can
        oldll = loglik
      }
      vec(i) = x
    }
    vec
}

This code works, and is reasonably fast and efficient, but there are several issues with it from a functional programmer’s perspective. One issue is that we have committed to storing all MCMC output in RAM in a DenseVector. This probably isn’t an issue here, but for some big problems we might prefer to not store the full set of states, but to just print the states to (say) the console, for possible re-direction to a file. It is easy enough to modify the code to do this:

def metrop2(n: Int = 1000, eps: Double = 0.5): Unit = {
    var x = 0.0
    var oldll = Gaussian(0.0, 1.0).logPdf(x)
    (1 to n).foreach { i =>
      val can = x + Uniform(-eps, eps).draw
      val loglik = Gaussian(0.0, 1.0).logPdf(can)
      val loga = loglik - oldll
      if (math.log(Uniform(0.0, 1.0).draw) < loga) {
        x = can
        oldll = loglik
      }
      println(x)
    }
}

But now we have two versions of the algorithm: one for storing results locally, and one for streaming results to the console. This is clearly unsatisfactory, but we shall return to this issue shortly. Another issue that will jump out at functional programmers is the reliance on mutable variables for storing the state and old likelihood. Let’s fix that now by rewriting the algorithm as a tail-recursive function.

import scala.annotation.tailrec

@tailrec
def metrop3(n: Int = 1000, eps: Double = 0.5, x: Double = 0.0, oldll: Double = Double.MinValue): Unit = {
    if (n > 0) {
      println(x)
      val can = x + Uniform(-eps, eps).draw
      val loglik = Gaussian(0.0, 1.0).logPdf(can)
      val loga = loglik - oldll
      if (math.log(Uniform(0.0, 1.0).draw) < loga)
        metrop3(n - 1, eps, can, loglik)
      else
        metrop3(n - 1, eps, x, oldll)
    }
  }

This has eliminated the vars, and is just as fast and efficient as the previous version of the code. Note that the @tailrec annotation is optional – it just asks the compiler to report a compile-time error if for some reason it cannot eliminate the tail call. However, this is for the print-to-console version of the code. What if we actually want to keep the iterations in RAM for subsequent analysis? We can keep the values in an accumulator, as follows.

@tailrec
def metrop4(n: Int = 1000, eps: Double = 0.5, x: Double = 0.0, oldll: Double = Double.MinValue, acc: List[Double] = Nil): DenseVector[Double] = {
    if (n == 0)
      DenseVector(acc.reverse.toArray)
    else {
      val can = x + Uniform(-eps, eps).draw
      val loglik = Gaussian(0.0, 1.0).logPdf(can)
      val loga = loglik - oldll
      if (math.log(Uniform(0.0, 1.0).draw) < loga)
        metrop4(n - 1, eps, can, loglik, can :: acc)
      else
        metrop4(n - 1, eps, x, oldll, x :: acc)
    }
}

Factoring out the updating logic

This is all fine, but we haven’t yet addressed the issue of having different versions of the code depending on what we want to do with the output. The problem is that we have tied up the logic of advancing the Markov chain with what to do with the output. What we need to do is separate out the code for advancing the state. We can do this by defining a new function.

def newState(x: Double, oldll: Double, eps: Double): (Double, Double) = {
    val can = x + Uniform(-eps, eps).draw
    val loglik = Gaussian(0.0, 1.0).logPdf(can)
    val loga = loglik - oldll
    if (math.log(Uniform(0.0, 1.0).draw) < loga) (can, loglik) else (x, oldll)
}

This function takes as input a current state and associated log likelihood and returns a new state and log likelihood following the execution of one step of a MH algorithm. This separates the concern of state updating from the rest of the code. So now if we want to write code that prints the state, we can write it as

  @tailrec
  def metrop5(n: Int = 1000, eps: Double = 0.5, x: Double = 0.0, oldll: Double = Double.MinValue): Unit = {
    if (n > 0) {
      println(x)
      val ns = newState(x, oldll, eps)
      metrop5(n - 1, eps, ns._1, ns._2)
    }
  }

and if we want to accumulate the set of states visited, we can write that as

  @tailrec
  def metrop6(n: Int = 1000, eps: Double = 0.5, x: Double = 0.0, oldll: Double = Double.MinValue, acc: List[Double] = Nil): DenseVector[Double] = {
    if (n == 0) DenseVector(acc.reverse.toArray) else {
      val ns = newState(x, oldll, eps)
      metrop6(n - 1, eps, ns._1, ns._2, ns._1 :: acc)
    }
  }

Both of these functions call newState to do the real work, and concentrate on what to do with the sequence of states. However, both of these functions repeat the logic of how to iterate over the sequence of states.

MCMC as a stream

Ideally we would like to abstract out the details of how to do state iteration from the code as well. Most functional languages have some concept of a Stream, which represents a (potentially infinite) sequence of states. The Stream can embody the logic of how to perform state iteration, allowing us to abstract that away from our code, as well.

To do this, we will restructure our code slightly so that it more clearly maps old state to new state.

def nextState(eps: Double)(state: (Double, Double)): (Double, Double) = {
    val x = state._1
    val oldll = state._2
    val can = x + Uniform(-eps, eps).draw
    val loglik = Gaussian(0.0, 1.0).logPdf(can)
    val loga = loglik - oldll
    if (math.log(Uniform(0.0, 1.0).draw) < loga) (can, loglik) else (x, oldll)
}

The "real" state of the chain is just x, but if we want to avoid recalculation of the old likelihood, then we need to make this part of the chain’s state. We can use this nextState function in order to construct a Stream.

  def metrop7(eps: Double = 0.5, x: Double = 0.0, oldll: Double = Double.MinValue): Stream[Double] =
    Stream.iterate((x, oldll))(nextState(eps)) map (_._1)

The result of calling this is an infinite stream of states. Obviously it isn’t computed in full, as that would require infinite computation, but it captures the logic of iteration and computation in a Stream, which can be thought of as a lazy List. We can get values out by converting the Stream to a regular collection, being careful to truncate the Stream to one of finite length beforehand! For example, metrop7().drop(1000).take(10000).toArray will do a burn-in of 1,000 iterations followed by a main monitoring run of length 10,000, capturing the results in an Array. Note that metrop7().drop(1000).take(10000) is a Stream, and so nothing is actually computed until the toArray is encountered. Alternatively, if printing to console is required, just replace the .toArray with .foreach(println).
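Thinning composes with burn-in in the same way, as just another transformation of the stream. The sketch below is a Breeze-free analogue of `metrop7`, assuming Scala 2.13+ so that `LazyList` is available. It uses `scala.util.Random`, the unnormalised log-density $-x^2/2$ in place of `Gaussian(0.0, 1.0).logPdf` (the normalising constant cancels in the acceptance ratio), initialises the stored log-target from `x0` rather than `Double.MinValue`, and introduces hypothetical names `nextStateR`, `chain` and `thin`.

```scala
import scala.util.Random

object StreamMH {
  val rng = new Random(123)

  // Unnormalised log-density of N(0,1); the constant cancels in the MH ratio.
  def logTarget(x: Double): Double = -0.5 * x * x

  // Hypothetical Metropolis step on the augmented state (x, log-target at x).
  def nextStateR(eps: Double)(state: (Double, Double)): (Double, Double) = {
    val (x, oldll) = state
    val can = x + (2 * rng.nextDouble() - 1) * eps // Uniform(-eps, eps) innovation
    val loglik = logTarget(can)
    if (math.log(rng.nextDouble()) < loglik - oldll) (can, loglik) else (x, oldll)
  }

  // Infinite lazy stream of states, keeping only x.
  def chain(eps: Double = 0.5, x0: Double = 0.0): LazyList[Double] =
    LazyList.iterate((x0, logTarget(x0)))(nextStateR(eps)).map(_._1)

  // Keep every k-th value of a (possibly infinite) LazyList; stays lazy.
  def thin[A](s: LazyList[A], k: Int): LazyList[A] =
    s.zipWithIndex.collect { case (a, i) if i % k == 0 => a }

  // Burn-in of 1,000, then 10,000 retained draws, keeping every 10th iteration.
  val draws: Array[Double] = thin(chain().drop(1000), 10).take(10000).toArray
}
```

Nothing is sampled until `toArray` forces the first 101,000 iterations; swapping `toArray` for `foreach(println)` streams the thinned output instead.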

The above stream-based approach to MCMC iteration is clean and elegant, and deals nicely with issues like burn-in and thinning (which can be handled similarly). This is how I typically write MCMC codes these days. However, functional programming purists would still have issues with this approach, as it isn’t quite pure functional: the code has a side-effect, which is to mutate the state of the underpinning pseudo-random number generator. If the code were pure, calling nextState with the same inputs would always give the same result. Clearly this isn’t the case here, as we have specifically designed the function to be stochastic, returning a randomly sampled value from the desired probability distribution. So nextState represents a function for randomly sampling from a conditional probability distribution.

A pure functional approach

Now, ultimately all code has side-effects, or there would be no point in running it! But in functional programming the desire is to make as much of the code as possible pure, and to push side-effects to the very edges of the code. So it’s fine to have side-effects in your main method, but not buried deep in your code. Here the side-effect is at the very heart of the code, which is why it is potentially an issue.

To keep things as simple as possible, at this point we will stop worrying about carrying forward the old likelihood, and hard-code a value of eps. Generalisation is straightforward. We can make our code pure by instead defining a function which represents the conditional probability distribution itself. For this we use a probability monad, which in Breeze is called Rand. We can couple together such functions using monadic binds (flatMap in Scala), expressed most neatly using for-comprehensions. So we can write our transition kernel as

import breeze.stats.distributions.Rand

def kernel(x: Double): Rand[Double] = for {
    innov <- Uniform(-0.5, 0.5)
    can = x + innov
    oldll = Gaussian(0.0, 1.0).logPdf(x)
    loglik = Gaussian(0.0, 1.0).logPdf(can)
    loga = loglik - oldll
    u <- Uniform(0.0, 1.0)
} yield if (math.log(u) < loga) can else x

This is now pure – the same input x will always return the same probability distribution – the conditional distribution of the next state given the current state. We can draw random samples from this distribution if we must, but it’s probably better to work as long as possible with pure functions. So next we need to encapsulate the iteration logic. Breeze has a MarkovChain object which can take kernels of this form and return a stochastic Process object representing the iteration logic, as follows.

import breeze.stats.distributions.MarkovChain

MarkovChain(0.0)(kernel).
  steps.
  drop(1000).
  take(10000).
  foreach(println)

The steps method contains the logic of how to advance the state of the chain. But again note that no computation actually takes place until the foreach method is encountered – this is when the sampling occurs and the side-effects happen.
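To see what purity buys without relying on Breeze internals, here is a deliberately tiny, hypothetical probability monad: a `Dist[A]` is just a function from an explicit random-number state to a value and a new state. Calling `kernel(x)` then only builds a description of the transition, and running it twice from the same seed gives the same result. The `Rng` type (a 64-bit linear congruential generator with Knuth's MMIX constants), `Dist`, `PureMH` and their methods are all illustrative assumptions, not how Breeze's `Rand` is implemented.

```scala
// Hypothetical pure RNG state: 64-bit LCG (Knuth's MMIX constants); Long overflow wraps mod 2^64.
final case class Rng(seed: Long) {
  def nextDouble: (Double, Rng) = {
    val s = seed * 6364136223846793005L + 1442695040888963407L
    ((s >>> 11).toDouble / (1L << 53).toDouble, Rng(s)) // top 53 bits -> [0, 1)
  }
}

// Hypothetical distribution type: consume some randomness, return a value and the new state.
final case class Dist[A](run: Rng => (A, Rng)) {
  def map[B](f: A => B): Dist[B] = Dist { r =>
    val (a, r2) = run(r)
    (f(a), r2)
  }
  def flatMap[B](f: A => Dist[B]): Dist[B] = Dist { r =>
    val (a, r2) = run(r)
    f(a).run(r2)
  }
}

object Dist {
  val standardUniform: Dist[Double] = Dist(_.nextDouble)
  def uniform(a: Double, b: Double): Dist[Double] = standardUniform.map(u => a + (b - a) * u)
}

object PureMH {
  def logTarget(x: Double): Double = -0.5 * x * x // N(0,1) up to a constant

  // Pure transition kernel: the same x always yields the same Dist.
  def kernel(x: Double): Dist[Double] = for {
    innov <- Dist.uniform(-0.5, 0.5)
    u <- Dist.standardUniform
  } yield {
    val can = x + innov
    if (math.log(u) < logTarget(can) - logTarget(x)) can else x
  }

  // The chain is a deterministic function of (x0, seed); randomness is threaded explicitly.
  def chain(x0: Double, seed: Long): Iterator[Double] =
    Iterator.iterate((x0, Rng(seed))) { case (x, r) => kernel(x).run(r) }.map(_._1)
}
```

The side-effect has been pushed to the edge: only the choice of seed (and whatever consumes `chain`) touches the outside world.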

Metropolis-Hastings is a common use-case for Markov chains, so Breeze actually has a helper method built-in that will construct a MH sampler directly from an initial state, a proposal kernel, and a (log) target.

MarkovChain.
  metropolisHastings(0.0, (x: Double) =>
  Uniform(x - 0.5, x + 0.5))(x =>
  Gaussian(0.0, 1.0).logPdf(x)).
  steps.
  drop(1000).
  take(10000).
  toArray

Note that if you are using the MH functionality in Breeze, it is important to make sure that you are using version 0.13 (or later), as I fixed a few issues with the MH code shortly prior to the 0.13 release.

Summary

Viewing MCMC algorithms as infinite streams of state is useful for writing elegant, generic, flexible code. Streams occur everywhere in programming, and so there are lots of libraries for working with them. In this post I used the simple Stream from the Scala standard library, but there are much more powerful and flexible stream libraries for Scala, including fs2 and Akka-streams. But whatever libraries you are using, the fundamental concepts are the same. The most straightforward approach to implementation is to define impure stochastic streams to consume. However, a pure functional approach is also possible, and the Breeze library defines some useful functions to facilitate this approach. I’m still a little bit ambivalent about whether the pure approach is worth the additional cognitive overhead, but it’s certainly very interesting and worth playing with and thinking about the pros and cons.

Complete runnable code for the examples in this post is available from my blog repo.