A probability monad for the bootstrap particle filter

Introduction

In the previous post I showed how to write your own general-purpose monadic probabilistic programming language from scratch in 50 lines of (Scala) code. That post is a pre-requisite for this one, so if you haven’t read it, go back and have a quick skim through it before proceeding. In that post I tried to keep everything as simple as possible, but at the expense of both elegance and efficiency. In this post I’ll address one problem with the implementation from that post – the memory (and computational) overhead associated with forming the Cartesian product of particle sets during monadic binding (flatMap). So if particle sets are typically of size N, then the Cartesian product is of size N^2, and multinomial resampling is applied to this set of size N^2 in order to sample back down to a set of size N. But this isn’t actually necessary. We can directly construct a set of size N, certainly saving memory, but also potentially saving computation time if the conditional distribution (on the right of the monadic bind) can be efficiently sampled. If we do this we will have a probability monad encapsulating the logic of a bootstrap particle filter, such as is often used for computing the filtering distribution of a state-space model in time series analysis. This simple change won’t solve the computational issues associated with deep monadic binding, but does solve the memory problem, and can lead to computationally efficient algorithms so long as care is taken in the formulation of probabilistic programs to ensure that deep monadic binding doesn’t occur. We’ll discuss that issue in the context of state-space models later, once we have our new SMC-based probability monad.

Materials for this post can be found in my blog repo, and a draft of this post itself can be found in the form of an executable tut document.

An SMC-based monad

The idea behind the approach to binding used in this monad is to mimic the “predict” step of a bootstrap particle filter. Here, for each particle in the source distribution, exactly one particle is drawn from the required conditional distribution and paired with the source particle, preserving the source particle’s original weight. So, in order to operationalise this, we will need to add a draw method to our probability monad. It will also simplify things to add a flatMap method to our Particle type constructor.

To follow along, you can type sbt console from the min-ppl2 directory of my blog repo, then paste blocks of code one at a time.

  import breeze.stats.{distributions => bdist}
  import breeze.linalg.DenseVector
  import cats._
  import cats.implicits._

  implicit val numParticles = 2000

  case class Particle[T](v: T, lw: Double) { // value and log-weight
    def map[S](f: T => S): Particle[S] = Particle(f(v), lw)
    def flatMap[S](f: T => Particle[S]): Particle[S] = {
      val ps = f(v)
      Particle(ps.v, lw + ps.lw)
    }
  }

I’ve added a dependence on cats here, so that we can use some derived methods, later. To take advantage of this, we must provide evidence that our custom types conform to standard type class interfaces. For example, we can provide evidence that Particle[_] is a monad as follows.

  implicit val particleMonad = new Monad[Particle] {
    def pure[T](t: T): Particle[T] = Particle(t, 0.0)
    def flatMap[T,S](pt: Particle[T])(f: T => Particle[S]): Particle[S] = pt.flatMap(f)
    def tailRecM[T,S](t: T)(f: T => Particle[Either[T,S]]): Particle[S] = ???
  }

The technical details are not important for this post, but we’ll see later what this can give us.

We can now define our Prob[_] monad in the following way.

  trait Prob[T] {
    val particles: Vector[Particle[T]]
    def draw: Particle[T]
    def mapP[S](f: T => Particle[S]): Prob[S] = Empirical(particles map (_ flatMap f))
    def map[S](f: T => S): Prob[S] = mapP(v => Particle(f(v), 0.0))
    def flatMap[S](f: T => Prob[S]): Prob[S] = mapP(f(_).draw)
    def resample(implicit N: Int): Prob[T] = {
      val lw = particles map (_.lw)
      val mx = lw reduce (math.max(_,_))
      val rw = lw map (lwi => math.exp(lwi - mx))
      val law = mx + math.log(rw.sum/(rw.length))
      val ind = bdist.Multinomial(DenseVector(rw.toArray)).sample(N)
      val newParticles = ind map (i => particles(i))
      Empirical(newParticles.toVector map (pi => Particle(pi.v, law)))
    }
    def cond(ll: T => Double): Prob[T] = mapP(v => Particle(v, ll(v)))
    def empirical: Vector[T] = resample.particles.map(_.v)
  }

  case class Empirical[T](particles: Vector[Particle[T]]) extends Prob[T] {
    def draw: Particle[T] = {
      val lw = particles map (_.lw)
      val mx = lw reduce (math.max(_,_))
      val rw = lw map (lwi => math.exp(lwi - mx))
      val law = mx + math.log(rw.sum/(rw.length))
      val idx = bdist.Multinomial(DenseVector(rw.toArray)).draw
      Particle(particles(idx).v, law)
    }
  }

As before, if you are pasting code blocks into the REPL, you will need to use :paste mode to paste these two definitions together.

The essential structure is similar to that from the previous post, but with a few notable differences. Most fundamentally, we now require any concrete implementation to provide a draw method returning a single particle from the distribution. As before, we are not worrying about the purity of functional code here, and are using a standard random number generator with globally mutating state. We can define a mapP method (for “map particle”) using the new flatMap method on Particle, and then use that to define map and flatMap for Prob[_]. Crucially, draw is used to define flatMap without requiring a Cartesian product of distributions to be formed.

We add a draw method to our Empirical[_] implementation. This method is computationally intensive (O(N) for each draw), so it will still be computationally problematic to chain several flatMaps together, but this will no longer be N^2 in memory utilisation. Note that again we carefully set the weight of the drawn particle so that its raw weight is the average of the raw weights of the empirical distribution. This is needed to propagate conditioning information correctly back through flatMaps. There is obviously some code duplication between the draw method on Empirical and the resample method on Prob, but I’m not sure it’s worth factoring out.
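To make the weight bookkeeping in draw concrete, here is a minimal standalone sketch using only the Scala standard library. The names DrawSketch and draw are illustrative assumptions, not part of the code above, and a simple cumulative-weight scan stands in for the Breeze Multinomial used in the real implementation.

```scala
import scala.util.Random

object DrawSketch {
  // Draw one value with probability proportional to exp(log-weight), and
  // return it together with the log of the average raw weight of the set.
  // (Hypothetical helper for illustration; not the Empirical.draw above.)
  def draw[T](vs: Vector[T], lw: Vector[Double], rng: Random): (T, Double) = {
    val mx = lw.max
    val rw = lw.map(l => math.exp(l - mx))       // shifted raw weights
    val law = mx + math.log(rw.sum / rw.length)  // log average raw weight
    val cum = rw.scanLeft(0.0)(_ + _).tail       // cumulative weights
    val u = rng.nextDouble() * cum.last
    val i = cum.indexWhere(_ > u)                // linear scan over the particles
    (vs(if (i < 0) vs.length - 1 else i), law)
  }
}
```

For example, DrawSketch.draw(Vector("a", "b"), Vector(0.0, Double.NegativeInfinity), new Random(1)) always selects "a", and the returned log-weight is log(0.5), the log of the average of the raw weights 1 and 0.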

It is worth noting that neither flatMap nor cond triggers resampling, so the user of the library is now responsible for resampling when appropriate.

We can provide evidence that Prob[_] forms a monad just like we did Particle[_].

  implicit val probMonad = new Monad[Prob] {
    def pure[T](t: T): Prob[T] = Empirical(Vector(Particle(t, 0.0)))
    def flatMap[T,S](pt: Prob[T])(f: T => Prob[S]): Prob[S] = pt.flatMap(f)
    def tailRecM[T,S](t: T)(f: T => Prob[Either[T,S]]): Prob[S] = ???
  }

Again, we’ll want to be able to create a distribution from an unweighted collection of values.

  def unweighted[T](ts: Vector[T], lw: Double = 0.0): Prob[T] =
    Empirical(ts map (Particle(_, lw)))

We will again define an implementation for distributions with tractable likelihoods, which are therefore easy to condition on. They will typically also be easy to draw from efficiently, and we will use this fact, too.

  trait Dist[T] extends Prob[T] {
    def ll(obs: T): Double
    def ll(obs: Seq[T]): Double = obs map (ll) reduce (_+_)
    def fit(obs: Seq[T]): Prob[T] = mapP(v => Particle(v, ll(obs)))
    def fitQ(obs: Seq[T]): Prob[T] = Empirical(Vector(Particle(obs.head, ll(obs))))
    def fit(obs: T): Prob[T] = fit(List(obs))
    def fitQ(obs: T): Prob[T] = fitQ(List(obs))
  }

We can give implementations of this for a few standard distributions.

  case class Normal(mu: Double, v: Double)(implicit N: Int) extends Dist[Double] {
    lazy val particles = unweighted(bdist.Gaussian(mu, math.sqrt(v)).
      sample(N).toVector).particles
    def draw = Particle(bdist.Gaussian(mu, math.sqrt(v)).draw, 0.0)
    def ll(obs: Double) = bdist.Gaussian(mu, math.sqrt(v)).logPdf(obs)
  }

  case class Gamma(a: Double, b: Double)(implicit N: Int) extends Dist[Double] {
    lazy val particles = unweighted(bdist.Gamma(a, 1.0/b).
      sample(N).toVector).particles
    def draw = Particle(bdist.Gamma(a, 1.0/b).draw, 0.0)
    def ll(obs: Double) = bdist.Gamma(a, 1.0/b).logPdf(obs)
  }

  case class Poisson(mu: Double)(implicit N: Int) extends Dist[Int] {
    lazy val particles = unweighted(bdist.Poisson(mu).
      sample(N).toVector).particles
    def draw = Particle(bdist.Poisson(mu).draw, 0.0)
    def ll(obs: Int) = bdist.Poisson(mu).logProbabilityOf(obs)
  }

Note that we now have to provide an (efficient) draw method for each implementation, returning a single draw from the distribution as a Particle with a (raw) weight of 1 (log weight of 0).

We are done. It’s a few more lines of code than that from the previous post, but this is now much closer to something that could be used in practice to solve actual inference problems using a reasonable number of particles. But to do so we will need to be careful to avoid deep monadic binding. This is easiest to explain with a concrete example.

Using the SMC-based probability monad in practice

Monadic binding and applicative structure

As explained in the previous post, using Scala’s for-expressions for monadic binding gives a natural and elegant PPL for our probability monad “for free”. This is fine, and in general there is no reason why using it should lead to inefficient code. However, for this particular probability monad implementation, it turns out that deep monadic binding comes with a huge performance penalty. For a concrete example, consider the following specification, perhaps of a prior distribution over some independent parameters.

    val prior = for {
      x <- Normal(0,1)
      y <- Gamma(1,1)
      z <- Poisson(10)
    } yield (x,y,z)

Don’t paste that into the REPL – it will take an age to complete!

Again, I must emphasise that there is nothing wrong with this specification, and there is no reason in principle why such a specification can’t be computationally efficient – it’s just a problem for our particular probability monad. We can begin to understand the problem by thinking about how this will be de-sugared by the compiler. Roughly speaking, the above will de-sugar to the following nested flatMaps.

    val prior2 =
      Normal(0,1) flatMap {x =>
        Gamma(1,1) flatMap {y =>
          Poisson(10) map {z =>
            (x,y,z)}}}

Again, beware of pasting this into the REPL.

So, although written from top to bottom, the nesting is such that the flatMaps collapse from the bottom up. The second flatMap (the first to collapse) isn’t such a problem here, as the Poisson has an O(1) draw method. But the result is an empirical distribution, which has an O(N) draw method. So the first flatMap (the second to collapse) is an O(N^2) operation. By extension, it’s easy to see that the computational cost of nested flatMaps will be exponential in the number of monadic binds. So, looking back at the for expression, the problem is that there are three <-. The last one isn’t a problem since it corresponds to a map, and the second last one isn’t a problem, since the final distribution is tractable with an O(1) draw method. The problem is the first <-, corresponding to a flatMap of one empirical distribution with respect to another. For our particular probability monad, it’s best to avoid these if possible.

The interesting thing to note here is that because the distributions are independent, there is no need for them to be sequenced. They could be defined in any order. In this case it makes sense to use the applicative structure implied by the monad.

Now, since we have told cats that Prob[_] is a monad, it can provide appropriate applicative methods for us automatically. In Cats, every monad is assumed to be also an applicative functor (which is true in Cartesian closed categories, and Cats implicitly assumes that all functors and monads are defined over CCCs). So we can give an alternative specification of the above prior using applicative composition.

 val prior3 = Applicative[Prob].tuple3(Normal(0,1), Gamma(1,1), Poisson(10))
// prior3: Wrapped.Prob[(Double, Double, Int)] = Empirical(Vector(Particle((-0.057088546468105204,0.03027578552505779,9),0.0), Particle((-0.43686658266043743,0.632210127012762,14),0.0), Particle((-0.8805715148936012,3.4799656228544706,4),0.0), Particle((-0.4371726407147289,0.0010707859994652403,12),0.0), Particle((2.0283297088320755,1.040984491158822,10),0.0), Particle((1.2971862986495886,0.189166705596747,14),0.0), Particle((-1.3111333817551083,0.01962422606642761,9),0.0), Particle((1.6573851896142737,2.4021836368401415,9),0.0), Particle((-0.909927220984726,0.019595551644771683,11),0.0), Particle((0.33888133893822464,0.2659823344145805,10),0.0), Particle((-0.3300797295729375,3.2714740256437667,10),0.0), Particle((-1.8520554352884224,0.6175322756460341,10),0.0), Particle((0.541156780497547...

This one is mathematically equivalent, but safe to paste into your REPL, as it does not involve deep monadic binding, and can be used whenever we want to compose together independent components of a probabilistic program. Note that “tupling” is not the only possibility – Cats provides a range of functions for manipulating applicative values.
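As a small illustration of some of those other applicative functions, here is a sketch that uses Option in place of Prob, purely so that it runs without the definitions above. The object name ApplicativeSketch is an assumption for illustration. The same calls should be available for any type with a Monad (and hence Applicative) instance in scope.

```scala
import cats._
import cats.implicits._

object ApplicativeSketch {
  val a: Option[Int] = Some(1)
  val b: Option[Int] = Some(2)
  val c: Option[Int] = Some(3)

  // tuple three independent values, as with prior3 above
  val tupled = Applicative[Option].tuple3(a, b, c)

  // combine three independent values with a function in one step
  val summed = Applicative[Option].map3(a, b, c)(_ + _ + _)

  // the same combination via tuple syntax from cats.implicits
  val viaSyntax = (a, b, c).mapN(_ + _ + _)
}
```

Here tupled is Some((1,2,3)), and both summed and viaSyntax are Some(6).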

This is one way to avoid deep monadic binding, but another strategy is to just break up a large for expression into separate smaller for expressions. We can examine this strategy in the context of state-space modelling.

Particle filtering for a non-linear state-space model

We can now re-visit the DGLM example from the previous post. We began by declaring some observations and a prior.

    val data = List(2,1,0,2,3,4,5,4,3,2,1)
// data: List[Int] = List(2, 1, 0, 2, 3, 4, 5, 4, 3, 2, 1)

    val prior = for {
      w <- Gamma(1, 1)
      state0 <- Normal(0.0, 2.0)
    } yield (w, List(state0))
// prior: Wrapped.Prob[(Double, List[Double])] = Empirical(Vector(Particle((4.220683377724395,List(0.37256749723762683)),0.0), Particle((0.4436668049925418,List(-1.0053578391265572)),0.0), Particle((0.9868899648436931,List(-0.6985099310193449)),0.0), Particle((0.13474375773634908,List(0.9099291736792412)),0.0), Particle((1.9654021747685184,List(-0.042127103727998175)),0.0), Particle((0.21761202474220223,List(1.1074616830012525)),0.0), Particle((0.31037163527711015,List(0.9261849914020324)),0.0), Particle((1.672438830781466,List(0.01678529855289384)),0.0), Particle((0.2257151759143097,List(2.5511304854128354)),0.0), Particle((0.3046489890769499,List(3.2918304533361398)),0.0), Particle((1.5115941814057159,List(-1.633612165168878)),0.0), Particle((1.4185906813831506,List(-0.8460922678989864))...

Looking carefully at the for-expression, there are just two <-, and the distribution on the RHS of the flatMap is tractable, so this is just O(N). So far so good.

Next, let’s look at the function to add a time point, which previously looked something like the following.

    def addTimePointSIS(current: Prob[(Double, List[Double])],
      obs: Int): Prob[(Double, List[Double])] = {
      println(s"Conditioning on observation: $obs")
      for {
        tup <- current
        (w, states) = tup
        os = states.head
        ns <- Normal(os, w)
        _ <- Poisson(math.exp(ns)).fitQ(obs)
      } yield (w, ns :: states)
    }
// addTimePointSIS: (current: Wrapped.Prob[(Double, List[Double])], obs: Int)Wrapped.Prob[(Double, List[Double])]

Recall that our new probability monad does not automatically trigger resampling, so applying this function in a fold will lead to a simple sequential importance sampling (SIS) particle filter. Typically, the bootstrap particle filter includes resampling after each time point, giving a special case of a sampling importance resampling (SIR) particle filter, which we could instead write as follows.

    def addTimePointSimple(current: Prob[(Double, List[Double])],
      obs: Int): Prob[(Double, List[Double])] = {
      println(s"Conditioning on observation: $obs")
      val updated = for {
        tup <- current
        (w, states) = tup
        os = states.head
        ns <- Normal(os, w)
        _ <- Poisson(math.exp(ns)).fitQ(obs)
      } yield (w, ns :: states)
      updated.resample
    }
// addTimePointSimple: (current: Wrapped.Prob[(Double, List[Double])], obs: Int)Wrapped.Prob[(Double, List[Double])]

This works fine, but we can see that there are three <- in this for expression. This leads to a flatMap with an empirical distribution on the RHS, and hence is O(N^2). But this is simple enough to fix, by separating the updating process into separate “predict” and “update” steps, which is how people typically formulate particle filters for state-space models, anyway. Here we could write that as

    def addTimePoint(current: Prob[(Double, List[Double])],
      obs: Int): Prob[(Double, List[Double])] = {
      println(s"Conditioning on observation: $obs")
      val predict = for {
        tup <- current
        (w, states) = tup
        os = states.head
        ns <- Normal(os, w)
      }
      yield (w, ns :: states)
      val updated = for {
        tup <- predict
        (w, states) = tup
        st = states.head
        _ <- Poisson(math.exp(st)).fitQ(obs)
      } yield (w, states)
      updated.resample
    }
// addTimePoint: (current: Wrapped.Prob[(Double, List[Double])], obs: Int)Wrapped.Prob[(Double, List[Double])]

By breaking the for expression into two: the first for the “predict” step and the second for the “update” (conditioning on the observation), we get two O(N) operations, which for large N is clearly much faster. We can then run the filter by folding over the observations.

  import breeze.stats.{meanAndVariance => meanVar}
// import breeze.stats.{meanAndVariance=>meanVar}

  val mod = data.foldLeft(prior)(addTimePoint(_,_)).empirical
// Conditioning on observation: 2
// Conditioning on observation: 1
// Conditioning on observation: 0
// Conditioning on observation: 2
// Conditioning on observation: 3
// Conditioning on observation: 4
// Conditioning on observation: 5
// Conditioning on observation: 4
// Conditioning on observation: 3
// Conditioning on observation: 2
// Conditioning on observation: 1
// mod: Vector[(Double, List[Double])] = Vector((0.24822528144246606,List(0.06290285371838457, 0.01633338109272575, 0.8997103339551227, 1.5058726341571411, 1.0579925693609091, 1.1616536515200064, 0.48325623593870665, 0.8457351097543767, -0.1988290999293708, -0.4787511341321954, -0.23212497417019512, -0.15327432440577277)), (1.111430233331792,List(0.6709342824443849, 0.009092797044165657, -0.13203367846117453, 0.4599952735399485, 1.3779288637042504, 0.6176597963402879, 0.6680455419800753, 0.48289163013446945, -0.5994001698510807, 0.4860969602653898, 0.10291798193078927, 1.2878325765987266)), (0.6118925941009055,List(0.6421161986636132, 0.679470360928868, 1.0552459559203342, 1.200835166087372, 1.3690372269589233, 1.8036766847282912, 0.6229883551656629, 0.14872642198313774, -0.122700856878725...

  meanVar(mod map (_._1)) // w
// res0: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(0.2839184023932576,0.07391602428256917,2000)

  meanVar(mod map (_._2.reverse.head)) // initial state
// res1: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(0.26057368528422714,0.4802810202354611,2000)

  meanVar(mod map (_._2.head)) // final state
// res2: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(0.5448036669181697,0.28293080584600894,2000)
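
For comparison, the predict, update and resample steps can also be written without any monadic machinery. The following is a rough standalone sketch using only the standard library, not the implementation above. It assumes a fixed, known state innovation standard deviation sd (whereas the model above treats the variance w as unknown), it tracks only the current state rather than the full path, and all of the names (BootstrapSketch, step, filter and so on) are hypothetical.

```scala
import scala.util.Random

object BootstrapSketch {
  // Poisson log-likelihood in terms of the log-rate, up to an additive
  // constant (the log(y!) term is the same for every particle, so is dropped)
  def poissonLl(y: Int, logRate: Double): Double = y * logRate - math.exp(logRate)

  // multinomial resampling with probabilities proportional to exp(lw)
  def resample(xs: Vector[Double], lw: Vector[Double], rng: Random): Vector[Double] = {
    val mx = lw.max
    val cum = lw.map(l => math.exp(l - mx)).scanLeft(0.0)(_ + _).tail
    Vector.fill(xs.length) {
      val u = rng.nextDouble() * cum.last
      val i = cum.indexWhere(_ > u)
      xs(if (i < 0) xs.length - 1 else i)
    }
  }

  // one time step of the filter
  def step(xs: Vector[Double], y: Int, sd: Double, rng: Random): Vector[Double] = {
    val predicted = xs.map(x => x + sd * rng.nextGaussian()) // predict: one draw per particle
    val lw = predicted.map(x => poissonLl(y, x))             // update: weight by the observation
    resample(predicted, lw, rng)                             // resample back to equal weights
  }

  // fold the step over the data, starting from state0 ~ N(0, variance 2)
  def filter(data: List[Int], n: Int, sd: Double, seed: Long): Vector[Double] = {
    val rng = new Random(seed)
    val init = Vector.fill(n)(math.sqrt(2.0) * rng.nextGaussian())
    data.foldLeft(init)((xs, y) => step(xs, y, sd, rng))
  }
}
```

Calling BootstrapSketch.filter(List(2,1,0,2,3,4,5,4,3,2,1), 2000, 0.5, 42L) returns 2000 equally weighted particles for the final state only. The structure of step mirrors the separate “predict” and “update” for expressions in addTimePoint, followed by the explicit resample.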

Summary and conclusions

Here we have just done some minor tidying up of the rather naive probability monad from the previous post to produce an SMC-based probability monad with improved performance characteristics. Again, we get an embedded probabilistic programming language “for free”. Although the language itself is very flexible, allowing us to construct more-or-less arbitrary probabilistic programs for Bayesian inference problems, we saw that a bug/feature of this particular inference algorithm is that care must be taken to avoid deep monadic binding if reasonable performance is to be obtained. In most cases this is simple to achieve by using applicative composition or by breaking up large for expressions.

There are still many issues and inefficiencies associated with this PPL. In particular, if the main intended application is to state-space models, it would make more sense to tailor the algorithms and implementations to exactly that case. On the other hand, if the main concern is a generic PPL, then it would make sense to make the PPL independent of the particular inference algorithm. These are both potential topics for future posts.

Software

  • min-ppl2 – code associated with this blog post
  • Rainier – a more efficient PPL with similar syntax
  • monad-bayes – a Haskell library exploring related ideas

Write your own general-purpose monadic probabilistic programming language from scratch in 50 lines of (Scala) code

Background

In May I attended a great workshop on advances and challenges in machine learning languages at the CMS in Cambridge. There was a good mix of people from different disciplines, and a bit of a theme around probabilistic programming. The workshop schedule includes links to many of the presentations, and is generally worth browsing. In particular, it includes a link to the slides for my presentation on a compositional approach to scalable Bayesian computation and probabilistic programming. I’ve given a few talks on this kind of thing over the last couple of years, at Newcastle, at the Isaac Newton Institute in Cambridge (twice), and at CIRM in France. But I think I explained things best at this workshop at the CMS, though my impression could partly have been a reflection of the more interested and relevant audience. In the talk I started with a basic explanation of why ideas from category theory and functional programming can help to solve problems in statistical computing in a more composable and scalable way, before moving on to discuss probability monads and their fundamental connection to probabilistic programming. The take home message from the talk is that if you have a generic inference algorithm, expressing the logic in the context of probability monads can give you an embedded probabilistic programming language (PPL) for that inference algorithm essentially “for free”.

So, during my talk I said something a little foolhardy. I can’t remember my exact words, but while presenting the idea behind an SMC-based probability monad I said something along the lines of “one day I will write a blog post on how to write a probabilistic programming language from scratch in 50 lines of code, and this is how I’ll do it”! Rather predictably (with hindsight), immediately after my talk about half a dozen people all pleaded with me to urgently write the post! I’ve been a little busy since then, but now that things have settled down a little for the summer, I’ve some time to think and code, so here is that post.

Introduction

The idea behind this post is to show that, if you think about the problem in the right way, and use a programming language with syntactic support for monadic composition, then producing a flexible, general, compositional, embedded domain specific language (DSL) for probabilistic programming based on a given generic inference algorithm is no more effort than hard-coding two or three illustrative examples. You would need to code up two or three examples for a paper anyway, but providing a PPL is way more useful. There is also an interesting converse to this, which is that if you can’t easily produce a PPL for your “general” inference algorithm, then perhaps it isn’t quite as “general” as you thought. I’ll try to resist exploring that here…

To illustrate these principles I want to develop a fairly minimal PPL, so that the complexities of the inference algorithm don’t hide the simplicity of the PPL embedding. Importance sampling with resampling is probably the simplest useful generic Bayesian inference algorithm to implement, so that’s what I’ll use. Note that there are many limitations of the approach that I will adopt, which will make it completely unsuitable for “real” problems. In particular, this implementation is:

  • inefficient, in terms of both compute time and memory usage
  • statistically inefficient for deep nesting and repeated conditioning, due to the particle degeneracy problem
  • specific to a particular probability monad
  • strictly evaluated
  • impure (due to mutation of global random number state)
  • etc.

All of these things are easily fixed, but all at the expense of greater abstraction, complexity and lines of code. I’ll probably discuss some of these generalisations and improvements in future posts, but for this post I want to keep everything as short and simple as practical. It’s also worth mentioning that there is nothing particularly original here. Many people have written about monadic embedded PPLs, and several have used an SMC-based monad for illustration. I’ll give some pointers to useful further reading at the end.

The language, in 50 lines of code

Without further ado, let’s just write the PPL. I’m using plain Scala, with just a dependency on the Breeze scientific library, which I’m going to use for simulating random numbers from standard distributions, and evaluation of their log densities. I have a directory of materials associated with this post in a git repo. This post is derived from an executable tut document (so you know it works), which can be found here. If you just want to follow along copying code at the command prompt, just run sbt from an empty or temp directory, and copy the following to spin up a Scala console with the Breeze dependency:

set libraryDependencies += "org.scalanlp" %% "breeze" % "1.0-RC4"
set libraryDependencies += "org.scalanlp" %% "breeze-natives" % "1.0-RC4"
set scalaVersion := "2.13.0"
console

We start with a couple of Breeze imports

import breeze.stats.{distributions => bdist}
import breeze.linalg.DenseVector

which are not strictly necessary, but clean up the subsequent code. We are going to use a set of weighted particles to represent a probability distribution empirically, so we’ll start by defining an appropriate ADT for these:

implicit val numParticles = 300

case class Particle[T](v: T, lw: Double) { // value and log-weight
  def map[S](f: T => S): Particle[S] = Particle(f(v), lw)
}

We also include a map method for pushing a particle through a transformation, and a default number of particles for sampling and resampling. 300 particles are enough for illustrative purposes. Ideally it would be good to increase this for more realistic experiments. We can use this particle type to build our main probability monad as follows.

trait Prob[T] {
  val particles: Vector[Particle[T]]
  def map[S](f: T => S): Prob[S] = Empirical(particles map (_ map f))
  def flatMap[S](f: T => Prob[S]): Prob[S] = {
    Empirical((particles map (p => {
      f(p.v).particles.map(psi => Particle(psi.v, p.lw + psi.lw))
    })).flatten).resample
  }
  def resample(implicit N: Int): Prob[T] = {
    val lw = particles map (_.lw)
    val mx = lw reduce (math.max(_,_))
    val rw = lw map (lwi => math.exp(lwi - mx))
    val law = mx + math.log(rw.sum/(rw.length))
    val ind = bdist.Multinomial(DenseVector(rw.toArray)).sample(N)
    val newParticles = ind map (i => particles(i))
    Empirical(newParticles.toVector map (pi => Particle(pi.v, law)))
  }
  def cond(ll: T => Double): Prob[T] =
    Empirical(particles map (p => Particle(p.v, p.lw + ll(p.v))))
  def empirical: Vector[T] = resample.particles.map(_.v)
}

case class Empirical[T](particles: Vector[Particle[T]]) extends Prob[T]

Note that if you are pasting into the Scala REPL you will need to use :paste mode for this. So Prob[_] is our base probability monad trait, and Empirical[_] is our simplest implementation, which is just a collection of weighted particles. The method flatMap forms the naive product of empirical measures and then resamples in order to stop an explosion in the number of particles. There are two things worth noting about the resample method. The first is that the log-sum-exp trick is being used to avoid overflow and underflow when the log weights are exponentiated. The second is that although the method returns an equally weighted set of particles, the log weights are all set in order that the average raw weight of the output set matches the average raw weight of the input set. This is a little tricky to explain, but it turns out to be necessary in order to correctly propagate conditioning information back through multiple monadic binds (flatMaps). The cond method allows conditioning of a distribution using an arbitrary log-likelihood. It is included for comparison with some other implementations I will refer to later, but we won’t actually be using it, so we could save two lines of code here if necessary. The empirical method just extracts an unweighted set of values from a distribution for subsequent analysis.
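
The value of the log-sum-exp trick is easy to see with a tiny standalone example. This is just an illustrative sketch (the object name LogSumExpSketch and the particular log weights are made up), comparing a naive computation of the log average weight with the shifted version used in resample.

```scala
object LogSumExpSketch {
  // three very negative log weights, as can arise after repeated conditioning
  val lw = Vector(-1000.0, -1001.0, -1002.0)

  // naive: every exp underflows to 0.0, so the log of the mean is -Infinity
  val naive = math.log(lw.map(l => math.exp(l)).sum / lw.length)

  // shifted: subtract the max first, so that the largest term is exp(0) = 1
  val mx = lw.max
  val stable = mx + math.log(lw.map(l => math.exp(l - mx)).sum / lw.length)
}
```

Here naive evaluates to negative infinity, so all information about the weights is lost, whereas stable is approximately -1000.691.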

It will be handy to have a function to turn a bunch of unweighted particles into a set of particles with equal weights (a sort-of inverse of the empirical method just described), so we can define that as follows.

def unweighted[T](ts: Vector[T], lw: Double = 0.0): Prob[T] =
  Empirical(ts map (Particle(_, lw)))

Probabilistic programming is essentially trivial if we only care about forward sampling. But interesting PPLs allow us to condition on observed values of random variables. In the context of SMC, this is simplest when the distribution being conditioned has a tractable log-likelihood. So we can now define an extension of our probability monad for distributions with a tractable log-likelihood, and define a bunch of convenient conditioning (or “fitting”) methods using it.

trait Dist[T] extends Prob[T] {
  def ll(obs: T): Double
  def ll(obs: Seq[T]): Double = obs map (ll) reduce (_+_)
  def fit(obs: Seq[T]): Prob[T] =
    Empirical(particles map (p => Particle(p.v, p.lw + ll(obs))))
  def fitQ(obs: Seq[T]): Prob[T] = Empirical(Vector(Particle(obs.head, ll(obs))))
  def fit(obs: T): Prob[T] = fit(List(obs))
  def fitQ(obs: T): Prob[T] = fitQ(List(obs))
}

The only unimplemented method is ll(). The fit method re-weights a particle set according to the observed log-likelihood. For convenience, it also returns a particle cloud representing the posterior-predictive distribution of an iid value from the same distribution. This is handy, but comes at the expense of introducing an additional particle cloud. So, if you aren’t interested in the posterior predictive, you can avoid this cost by using the fitQ method (for “fit quick”), which doesn’t return anything useful. We’ll see examples of this in practice, shortly. Note that the fitQ methods aren’t strictly required for our “minimal” PPL, so we can save a couple of lines by omitting them if necessary. Similarly for the variants which allow conditioning on a collection of iid observations from the same distribution.
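
To see why adding a log-likelihood to a log-weight amounts to conditioning, consider the net effect of a fitQ inside a for-expression: each particle of the parameter ends up weighted by the likelihood of the observation given that particle. The following standalone sketch (standard library only, with hypothetical names ReweightSketch, normalLl and posteriorMean) does this by hand for a toy model where the exact answer is known.

```scala
import scala.util.Random

object ReweightSketch {
  // Gaussian log-density with mean m and variance v
  def normalLl(x: Double, m: Double, v: Double): Double =
    -0.5 * math.log(2 * math.Pi * v) - (x - m) * (x - m) / (2 * v)

  // Weighted-particle estimate of the posterior mean of mu, given a single
  // observation y ~ N(mu, 1) and a prior mu ~ N(0, 1)
  def posteriorMean(y: Double, n: Int, seed: Long): Double = {
    val rng = new Random(seed)
    val mus = Vector.fill(n)(rng.nextGaussian())  // prior particles, log-weight 0
    val lw = mus.map(mu => normalLl(y, mu, 1.0))  // conditioning adds the log-likelihood
    val mx = lw.max
    val w = lw.map(l => math.exp(l - mx))         // shifted raw weights
    mus.zip(w).map { case (m, wi) => m * wi }.sum / w.sum
  }
}
```

For this conjugate model the exact posterior mean is y/2, so ReweightSketch.posteriorMean(1.0, 20000, 1L) should be close to 0.5.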

At this point we are essentially done. But for convenience, we can define a few standard distributions to help get new users of our PPL started. Of course, since the PPL is embedded, it is trivial to add our own additional distributions later.

case class Normal(mu: Double, v: Double)(implicit N: Int) extends Dist[Double] {
  lazy val particles = unweighted(bdist.Gaussian(mu, math.sqrt(v)).sample(N).toVector).particles
  def ll(obs: Double) = bdist.Gaussian(mu, math.sqrt(v)).logPdf(obs) }

case class Gamma(a: Double, b: Double)(implicit N: Int) extends Dist[Double] {
  lazy val particles = unweighted(bdist.Gamma(a, 1.0/b).sample(N).toVector).particles
  def ll(obs: Double) = bdist.Gamma(a, 1.0/b).logPdf(obs) }

case class Poisson(mu: Double)(implicit N: Int) extends Dist[Int] {
  lazy val particles = unweighted(bdist.Poisson(mu).sample(N).toVector).particles
  def ll(obs: Int) = bdist.Poisson(mu).logProbabilityOf(obs) }

Note that I’ve parameterised the Normal and Gamma the way that statisticians usually do, and not the way they are usually parameterised in scientific computing libraries (such as Breeze).
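To spell out the conversion, here is a small sketch of the mapping between the two parameterisations. The names ParamSketch, normalToMeanSd, gammaToShapeScale and gammaMean are made up for illustration and are not part of the PPL; the formulas simply mirror the math.sqrt(v) and 1.0/b appearing in the definitions above.

```scala
object ParamSketch {
  // Normal: (mean, variance) as used here -> (mean, standard deviation) as used by Breeze
  def normalToMeanSd(mu: Double, v: Double): (Double, Double) = (mu, math.sqrt(v))

  // Gamma: (shape, rate) as used here -> (shape, scale) as used by Breeze, scale = 1 / rate
  def gammaToShapeScale(a: Double, b: Double): (Double, Double) = (a, 1.0 / b)

  // Mean of a Gamma distribution with shape a and rate b
  def gammaMean(a: Double, b: Double): Double = a / b
}
```

So, for example, Normal(0, 100) has a standard deviation of 10, and Gamma(1, 0.1) has shape 1, rate 0.1, and hence mean 10.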

That’s it! This is a complete, general-purpose, composable, monadic PPL, in 50 (actually, 48, and fewer still if you discount trailing braces) lines of code. Let’s now see how it works in practice.

Examples

Normal random sample

We’ll start off with just about the simplest slightly interesting example I can think of: Bayesian inference for the mean and variance of a normal distribution from a random sample.

import breeze.stats.{meanAndVariance => meanVar}
// import breeze.stats.{meanAndVariance=>meanVar}

val mod = for {
  mu <- Normal(0, 100)
  tau <- Gamma(1, 0.1)
  _ <- Normal(mu, 1.0/tau).fitQ(List(8.0,9,7,7,8,10))
} yield (mu,tau)
// mod: Wrapped.Prob[(Double, Double)] = Empirical(Vector(Particle((8.718127116254472,0.93059589932682),-15.21683812389373), Particle((7.977706390420308,1.1575288208065433),-15.21683812389373), Particle((7.977706390420308,1.1744750937611985),-15.21683812389373), Particle((7.328100552769214,1.1181787982959164),-15.21683812389373), Particle((7.977706390420308,0.8283737237370494),-15.21683812389373), Particle((8.592847414557049,2.2934836446009026),-15.21683812389373), Particle((8.718127116254472,1.498741032928539),-15.21683812389373), Particle((8.592847414557049,0.2506065368748732),-15.21683812389373), Particle((8.543283880264225,1.127386759627675),-15.21683812389373), Particle((7.977706390420308,1.3508728798704925),-15.21683812389373), Particle((7.977706390420308,1.1134430556990933),-15.2168...

val modEmp = mod.empirical
// modEmp: Vector[(Double, Double)] = Vector((7.977706390420308,0.8748006833362748), (6.292345096890432,0.20108091703626174), (9.15330820843396,0.7654238730107492), (8.960935105658741,1.027712984079369), (7.455292602273359,0.49495749079351836), (6.911716909394562,0.7739749058662421), (6.911716909394562,0.6353785792877397), (7.977706390420308,1.1744750937611985), (7.977706390420308,1.1134430556990933), (8.718127116254472,1.166399872049532), (8.763777227034538,1.0468304705769353), (8.718127116254472,0.93059589932682), (7.328100552769214,1.6166695922250236), (8.543283880264225,0.4689300351248357), (8.543283880264225,2.0028918490755094), (7.536025958690963,0.6282318170458533), (7.328100552769214,1.6166695922250236), (7.049843463553113,0.20149378088848635), (7.536025958690963,2.3565657669819897...

meanVar(modEmp map (_._1)) // mu
// res0: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(8.311171010932343,0.4617800639333532,300)

meanVar(modEmp map (_._2)) // tau
// res1: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(0.940762723934599,0.23641881704888842,300)

Note the use of the empirical method to turn the distribution into an unweighted set of particles for Monte Carlo analysis. Anyway, the main point is that the syntactic sugar for monadic binds (flatMaps) provided by Scala’s for-expressions (similar to do-notation in Haskell) leads to readable code not so different to that in well-known general-purpose PPLs such as BUGS, JAGS, or Stan. There are some important differences, however. In particular, the embedded DSL has probabilistic programs as regular values in the host language. These may be manipulated and composed like other values. This makes this probabilistic programming language more composable than the aforementioned languages, which makes it much simpler to build large, complex probabilistic programs from simpler, well-tested, components, in a scalable way. That is, this PPL we have obtained “for free” is actually in many ways better than most well-known PPLs.
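The PPL's own resampling code is defined earlier in the post. As a reminder of the idea, the following is an independent, minimal sketch of the kind of multinomial resampling that turns a weighted particle set into an unweighted sample. The name resample and its signature are assumptions made for illustration, and it uses scala.util.Random rather than Breeze to keep it self-contained.

```scala
import scala.util.Random

object ResampleSketch {
  // Draw n values with probability proportional to exp(log-weight)
  def resample[T](vs: Vector[T], lws: Vector[Double], n: Int, rng: Random): Vector[T] = {
    val mx = lws.max
    val ws = lws.map(lw => math.exp(lw - mx)) // rescale so the largest weight is 1
    val cum = ws.scanLeft(0.0)(_ + _).tail    // cumulative (unnormalised) weights
    Vector.fill(n) {
      val u = rng.nextDouble() * cum.last     // uniform on [0, total weight)
      val i = cum.indexWhere(_ > u)           // first particle whose cumulative weight exceeds u
      vs(if (i < 0) vs.length - 1 else i)     // guard against rounding at the upper end
    }
  }
}
```

The output contains only values, with no weights attached, which is what makes it convenient for computing Monte Carlo summaries such as means and variances.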

Noisy measurements of a count

Here we’ll look at the problem of inference for a discrete count given some noisy iid continuous measurements of it.

val mod = for {
  count <- Poisson(10)
  tau <- Gamma(1, 0.1)
  _ <- Normal(count, 1.0/tau).fitQ(List(4.2,5.1,4.6,3.3,4.7,5.3))
} yield (count, tau)
// mod: Wrapped.Prob[(Int, Double)] = Empirical(Vector(Particle((5,4.488795220669575),-11.591037521513753), Particle((5,1.7792314573063672),-11.591037521513753), Particle((5,2.5238021156137673),-11.591037521513753), Particle((4,3.280754333896923),-11.591037521513753), Particle((5,2.768438569482849),-11.591037521513753), Particle((4,1.3399975573518912),-11.591037521513753), Particle((5,1.1792835858615431),-11.591037521513753), Particle((5,1.989491156206883),-11.591037521513753), Particle((4,0.7825254987152054),-11.591037521513753), Particle((5,2.7113936834028793),-11.591037521513753), Particle((5,3.7615196800240387),-11.591037521513753), Particle((4,1.6833300961124709),-11.591037521513753), Particle((5,2.749183220798113),-11.591037521513753), Particle((5,2.1074062883430202),-11.591037521513...

val modEmp = mod.empirical
// modEmp: Vector[(Int, Double)] = Vector((4,3.243786594839479), (4,1.5090869158886693), (4,1.280656912383482), (5,2.0616356908358195), (5,3.475433097869503), (5,1.887582611202514), (5,2.8268877720514745), (5,0.9193261688050818), (4,1.7063629502805908), (5,2.116414832864841), (5,3.775508828984636), (5,2.6774941123762814), (5,2.937859946593459), (5,1.2047689975166402), (5,2.5658806161572656), (5,1.925890364268593), (4,1.0194093176888832), (5,1.883288825936725), (5,4.9503779454422965), (5,0.9045613180858916), (4,1.5795027943928661), (5,1.925890364268593), (5,2.198539449287062), (5,1.791363956348445), (5,0.9853760689818026), (4,1.6541388923071607), (5,2.599899960899971), (4,1.8904423810277957), (5,3.8983183765907836), (5,1.9242319515895554), (5,2.8268877720514745), (4,1.772120802027519), (5,2...

meanVar(modEmp map (_._1.toDouble)) // count
// res2: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(4.670000000000004,0.23521739130434777,300)

meanVar(modEmp map (_._2)) // tau
// res3: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(1.9678279101913874,0.9603971613375548,300)

I’ve included this mainly as an example of inference for a discrete-valued parameter. There are people out there who will tell you that discrete parameters are bad/evil/impossible. This isn’t true – discrete parameters are cool!

Linear model

Because our PPL is embedded, we can take full advantage of the power of the host programming language to build our models. Let’s explore this in the context of Bayesian estimation of a linear model. We’ll start with some data.

val x = List(1.0,2,3,4,5,6)
// x: List[Double] = List(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)

val y = List(3.0,2,4,5,5,6)
// y: List[Double] = List(3.0, 2.0, 4.0, 5.0, 5.0, 6.0)

val xy = x zip y
// xy: List[(Double, Double)] = List((1.0,3.0), (2.0,2.0), (3.0,4.0), (4.0,5.0), (5.0,5.0), (6.0,6.0))

Now, our (simple) linear regression model will be parameterised by an intercept, alpha, a slope, beta, and a residual variance, v. So, for convenience, let’s define an ADT representing a particular linear model.

case class Param(alpha: Double, beta: Double, v: Double)
// defined class Param

Now we can define a prior distribution over models as follows.

val prior = for {
  alpha <- Normal(0,10)
  beta <- Normal(0,4)
  v <- Gamma(1,0.1)
} yield Param(alpha, beta, v)
// prior: Wrapped.Prob[Param] = Empirical(Vector(Particle(Param(-2.392517550699654,-3.7516090283880095,1.724680963054379),0.0), Particle(Param(7.60982717067903,-1.4318199629361292,2.9436745225038545),0.0), Particle(Param(-1.0281832158124837,-0.2799562317845073,4.05125312048092),0.0), Particle(Param(-1.0509321093485073,-2.4733837587060448,0.5856868459456287),0.0), Particle(Param(7.678898742733517,0.15616204936412104,5.064540017623097),0.0), Particle(Param(-3.392028985658713,-0.694412176170572,7.452625596437611),0.0), Particle(Param(3.0310535934425324,-2.97938526497514,2.138446100857938),0.0), Particle(Param(3.016959696424399,1.3370878561954143,6.18957854813488),0.0), Particle(Param(2.6956505371497066,1.058845844793446,5.257973123790336),0.0), Particle(Param(1.496225540527873,-1.573936445746...

Since our language doesn’t include any direct syntactic support for fitting regression models, we can define our own function for conditioning a distribution over models on a data point, which we can then apply to our prior as a fold over the available data.

def addPoint(current: Prob[Param], obs: (Double, Double)): Prob[Param] = for {
    p <- current
    (x, y) = obs
    _ <- Normal(p.alpha + p.beta * x, p.v).fitQ(y)
  } yield p
// addPoint: (current: Wrapped.Prob[Param], obs: (Double, Double))Wrapped.Prob[Param]

val mod = xy.foldLeft(prior)(addPoint(_,_)).empirical
// mod: Vector[Param] = Vector(Param(1.4386051853067798,0.8900831186754122,4.185564696221981), Param(0.5530582357040271,1.1296886766045509,3.468527573093037), Param(0.6302560079049571,0.9396563485293532,3.7044543917875927), Param(3.68291303096638,0.4781372802435529,5.151665328789926), Param(3.016959696424399,0.4438016738989412,1.9988914122633519), Param(3.016959696424399,0.4438016738989412,1.9988914122633519), Param(0.6302560079049571,0.9396563485293532,3.7044543917875927), Param(0.6302560079049571,0.9396563485293532,3.7044543917875927), Param(3.68291303096638,0.4781372802435529,5.151665328789926), Param(3.016959696424399,0.4438016738989412,1.9988914122633519), Param(0.6302560079049571,0.9396563485293532,3.7044543917875927), Param(0.6302560079049571,0.9396563485293532,3.7044543917875927), ...

meanVar(mod map (_.alpha))
// res4: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(1.5740812481283812,1.893684802867127,300)

meanVar(mod map (_.beta))
// res5: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(0.7690238868623273,0.1054479268115053,300)

meanVar(mod map (_.v))
// res6: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(3.5240853748668695,2.793386340338213,300)
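As a rough sanity check on these numbers (not part of the PPL), the ordinary least squares fit to the same data can be computed directly. With fairly vague priors, we would expect the posterior means for alpha and beta to be in the same ballpark as the least squares estimates. The name OlsSketch.fit is made up for illustration.

```scala
object OlsSketch {
  // Returns (intercept, slope) of the least squares line through the points
  def fit(x: Seq[Double], y: Seq[Double]): (Double, Double) = {
    val xbar = x.sum / x.length
    val ybar = y.sum / y.length
    val sxy = x.zip(y).map { case (xi, yi) => (xi - xbar) * (yi - ybar) }.sum
    val sxx = x.map(xi => (xi - xbar) * (xi - xbar)).sum
    val slope = sxy / sxx
    (ybar - slope * xbar, slope)
  }
}
```

For the data above, OlsSketch.fit(x, y) gives an intercept of roughly 1.67 and a slope of roughly 0.71, which is broadly consistent with the posterior means reported for alpha and beta.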

We could easily add syntactic support to our language to enable the fitting of regression-style models, as is done in Rainier, of which more later.

Dynamic generalised linear model

The previous examples have been fairly simple, so let’s finish with something a bit less trivial. Our language is quite flexible enough to allow the analysis of a dynamic generalised linear model (DGLM). Here we’ll fit a Poisson DGLM with a log-link and a simple Brownian state evolution. More complex models are more-or-less similarly straightforward. The model is parameterised by an initial state, state0, and an evolution variance, w.

val data = List(2,1,0,2,3,4,5,4,3,2,1)
// data: List[Int] = List(2, 1, 0, 2, 3, 4, 5, 4, 3, 2, 1)

val prior = for {
  w <- Gamma(1, 1)
  state0 <- Normal(0.0, 2.0)
} yield (w, List(state0))
// prior: Wrapped.Prob[(Double, List[Double])] = Empirical(Vector(Particle((0.12864918092587044,List(-2.862479260552014)),0.0), Particle((1.1706344622093179,List(1.6138397233532091)),0.0), Particle((0.757288087950638,List(-0.3683499919402798)),0.0), Particle((2.755201217523856,List(-0.6527488751780317)),0.0), Particle((0.7535085397802043,List(0.5135562407906502)),0.0), Particle((1.1630726564525629,List(0.9703146201262348)),0.0), Particle((1.0080345715326213,List(-0.375686732266234)),0.0), Particle((4.603723117526974,List(-1.6977366375222938)),0.0), Particle((0.2870669117815037,List(2.2732160435099433)),0.0), Particle((2.454675218313211,List(-0.4148287542786906)),0.0), Particle((0.3612534201761152,List(-1.0099270904161748)),0.0), Particle((0.29578453393473114,List(-2.4938128878051966)),0.0)...

We can define a function to create a new hidden state, prepend it to the list of hidden states, and condition on the observed value at that time point as follows.

def addTimePoint(current: Prob[(Double, List[Double])],
  obs: Int): Prob[(Double, List[Double])] = for {
  tup <- current
  (w, states) = tup
  os = states.head
  ns <- Normal(os, w)
  _ <- Poisson(math.exp(ns)).fitQ(obs)
} yield (w, ns :: states)
// addTimePoint: (current: Wrapped.Prob[(Double, List[Double])], obs: Int)Wrapped.Prob[(Double, List[Double])]
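Written out in mathematical notation, the model encoded by prior and addTimePoint is as follows. The symbols x_t (hidden state), y_t (observed count) and n (number of observations) are notation introduced here rather than names from the code, the Gamma is parameterised by shape and rate, and the second argument of each normal distribution is a variance, consistent with the definitions earlier in the post.

```latex
\begin{aligned}
w &\sim \mathrm{Gamma}(1, 1), \\
x_0 &\sim N(0, 2), \\
x_t \mid x_{t-1}, w &\sim N(x_{t-1}, w), \qquad t = 1, \ldots, n, \\
y_t \mid x_t &\sim \mathrm{Poisson}\left(\exp(x_t)\right), \qquad t = 1, \ldots, n.
\end{aligned}
```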

We then run our (augmented state) particle filter as a fold over the time series.

val mod = data.foldLeft(prior)(addTimePoint(_,_)).empirical
// mod: Vector[(Double, List[Double])] = Vector((0.053073252551193446,List(0.8693030057529023, 1.2746526177834938, 1.020307245610461, 1.106341696651584, 1.070777529635013, 0.8749041525303247, 0.9866999164354662, 0.4082577920509255, 0.06903234462140699, -0.018835642776197814, -0.16841912034400547, -0.08919045681401294)), (0.0988871875952762,List(-0.24241948109998607, 0.09321618969352086, 0.9650532206325375, 1.1738734442767293, 1.2272325310228442, 0.9791695328246326, 0.5576319082578128, -0.0054280215024367084, 0.4256621012454391, 0.7486862644576158, 0.8193517409118243, 0.5928750312493785)), (0.16128799384962295,List(-0.30371187329667104, -0.3976854602292066, 0.5869357473774455, 0.9881090696832543, 1.2095181380307558, 0.7211231597865506, 0.8085486452269925, 0.2664373341459165, -0.627344024142...

meanVar(mod map (_._1)) // w
// res7: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(0.29497487517435844,0.0831412016262515,300)

meanVar(mod map (_._2.reverse.head)) // state0 (initial state)
// res8: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(0.04617218427664018,0.372844704533101,300)

meanVar(mod map (_._2.head)) // stateN (final state)
// res9: breeze.stats.meanAndVariance.MeanAndVariance = MeanAndVariance(0.4937178761565612,0.2889287607470016,300)

Summary, conclusions, and further reading

So, we’ve seen how we can build a fully functional, general-purpose, compositional, monadic PPL from scratch in 50 lines of code, and we’ve seen how we can use it to solve real, analytically intractable Bayesian inference problems of non-trivial complexity. Of course, there are many limitations to using exactly this PPL implementation in practice. The algorithm becomes intolerably slow for deeply nested models, and uses unreasonably large amounts of RAM for large numbers of particles. It also suffers from a particle degeneracy problem if there are too many conditioning events. But it is important to understand that these are all deficiencies of the naive inference algorithm used, not the PPL itself. The PPL is flexible and compositional and can be used to build models of arbitrary size and complexity – it just needs to be underpinned by a better, more efficient, inference algorithm. Rainier is a Scala library I’ve blogged about previously which uses a very similar PPL to the one described here, but is instead underpinned by a fast, efficient, HMC algorithm. With my student Jonny Law, we have recently arXived a paper on Functional probabilistic programming for scalable Bayesian modelling, discussing some of these issues, and exploring the compositional nature of monadic PPLs (somewhat glossed over in this post).

Since the same PPL can be underpinned by different inference algorithms encapsulated as probability monads, an obvious question is whether it is possible to abstract the PPL away from the inference algorithm implementation. Of course, the answer is “yes”, and this has been explored to great effect in papers such as Practical probabilistic programming with monads and Functional programming for modular Bayesian inference. Note that they use the cond approach to conditioning, which looks a bit unwieldy, but is equivalent to fitting. As well as allowing alternative inference algorithms to be applied to the same probabilistic program, it also enables the composing of inference algorithms – for example, composing a MH algorithm with an SMC algorithm in order to get a PMMH algorithm. The ideas are implemented in an embedded DSL for Haskell, monad-bayes. If you are not used to Haskell, the syntax will probably seem a bit more intimidating than Scala’s, but the semantics are actually quite similar, with the main semantic difference being that Scala is strictly evaluated by default, whereas Haskell is lazily evaluated by default. Both languages support both lazy and strict evaluation – the difference relates simply to default behaviour, but is important nevertheless.

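Papers

  • Functional probabilistic programming for scalable Bayesian modelling
  • Practical probabilistic programming with monads
  • Functional programming for modular Bayesian inference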

Software

  • min-ppl – code associated with this blog post
  • Rainier – a more efficient PPL with similar syntax
  • monad-bayes – a Haskell library exploring related ideas

Bayesian hierarchical modelling with Rainier

Introduction

In the previous post I gave a brief introduction to Rainier, a new HMC-based probabilistic programming library/DSL for Scala. In that post I assumed that people were using the latest source version of the library. Since then, version 0.1.1 of the library has been released, so in this post I will demonstrate use of the released version of the software (using the binaries published to Sonatype), and will walk through a slightly more interesting example – a dynamic linear state space model with unknown static parameters. This is similar to, but slightly different from, the DLM example in the Rainier library. So to follow along with this post, all that is required is SBT.

An interactive session

First run SBT from an empty directory, and paste the following at the SBT prompt:

set libraryDependencies  += "com.stripe" %% "rainier-plot" % "0.1.1"
set scalaVersion := "2.12.4"
console

This should give a Scala REPL with appropriate dependencies (rainier-plot has all of the relevant transitive dependencies). We’ll begin with some imports, and then simulate some synthetic data from a dynamic linear state space model with an AR(1) latent state and Gaussian noise on the observations.

import com.stripe.rainier.compute._
import com.stripe.rainier.core._
import com.stripe.rainier.sampler._

implicit val rng = ScalaRNG(1)
val n = 60 // number of observations/time points
val mu = 3.0 // AR(1) mean
val a = 0.95 // auto-regressive parameter
val sig = 0.2 // AR(1) SD
val sigD = 3.0 // observational SD
val state = Stream.
  iterate(0.0)(x => mu + (x - mu) * a + sig * rng.standardNormal).
  take(n).toVector
val obs = state.map(_ + sigD * rng.standardNormal)
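In mathematical notation, the simulation code above corresponds to the following model, where x_t denotes the latent state and y_t the observation (these symbols, along with the noise terms, are notation introduced here rather than names from the code).

```latex
\begin{aligned}
x_0 &= 0, \\
x_t &= \mu + a\,(x_{t-1} - \mu) + \sigma\,\epsilon_t, \qquad \epsilon_t \sim N(0, 1), \\
y_t &= x_t + \sigma_D\,\eta_t, \qquad \eta_t \sim N(0, 1),
\end{aligned}
```

with \mu = 3, a = 0.95, \sigma = 0.2 and \sigma_D = 3 matching mu, a, sig and sigD. Note that the state equation can equivalently be written as x_t = (1 - a)\mu + a x_{t-1} + \sigma \epsilon_t.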

Now we have some synthetic data, let’s think about building a probabilistic program for this model. Start with a prior.

case class Static(mu: Real, a: Real, sig: Real, sigD: Real)
val prior = for {
  mu <- Normal(0, 10).param
  a <- Normal(1, 0.1).param
  sig <- Gamma(2,1).param
  sigD <- Gamma(2,2).param
  sp <- Normal(0, 50).param
} yield (Static(mu, a, sig, sigD), List(sp))

Note the use of a case class for wrapping the static parameters. Next, let’s define a function to add a state and associated observation to an existing model.

def addTimePoint(current: RandomVariable[(Static, List[Real])],
                     datum: Double) = for {
  tup <- current
  static = tup._1
  states = tup._2
  os = states.head
  ns <- Normal(((Real.one - static.a) * static.mu) + (static.a * os),
                 static.sig).param
  _ <- Normal(ns, static.sigD).fit(datum)
} yield (static, ns :: states)

Given this, we can generate the probabilistic program for our model as a fold over the data initialised with the prior.

val fullModel = obs.foldLeft(prior)(addTimePoint(_, _))

If we don’t want to keep samples for all of the variables, we can focus on the parameters of interest, wrapping the results in a Map for convenient sampling and plotting.

val model = for {
  tup <- fullModel
  static = tup._1
  states = tup._2
} yield
  Map("mu" -> static.mu,
  "a" -> static.a,
  "sig" -> static.sig,
  "sigD" -> static.sigD,
  "SP" -> states.reverse.head)

We can sample with

val out = model.sample(HMC(3), 100000, 10000 * 500, 500)

(this will take several minutes) and plot some diagnostics with

import com.cibo.evilplot.geometry.Extent
import com.stripe.rainier.plot.EvilTracePlot._

val truth = Map("mu" -> mu, "a" -> a, "sigD" -> sigD,
  "sig" -> sig, "SP" -> state(0))
render(traces(out, truth), "traceplots.png",
  Extent(1200, 1400))
render(pairs(out, truth), "pairs.png")

This generates the following diagnostic plots:

Everything looks good.

Summary

Rainier is a monadic embedded DSL for probabilistic programming in Scala. We can use standard functional combinators and for-expressions for building models to sample, and then run an efficient HMC algorithm on the resulting probability monad in order to obtain samples from the posterior distribution of the model.

See the Rainier repo for further details.

A scalable particle filter in Scala

Introduction

Many modern algorithms in computational Bayesian statistics have at their heart a particle filter or some other sequential Monte Carlo (SMC) procedure. In this blog I’ve discussed particle MCMC algorithms which use a particle filter in the inner-loop in order to compute a (noisy, unbiased) estimate of the marginal likelihood of the data. These algorithms are often very computationally intensive, either because the forward model used to propagate the particles is expensive, or because the likelihood associated with each particle/observation is expensive (or both). In this case it is desirable to parallelise the particle filter to run on all available cores of a machine, or in some cases, it would even be desirable to distribute the particle filter computation across a cluster of machines.

Parallelisation is difficult when using the conventional imperative programming languages typically used in scientific and statistical computing, but is much easier using modern functional languages such as Scala. In fact, in languages such as Scala it is possible to describe algorithms at a higher level of abstraction, so that exactly the same algorithm can run in serial, run in parallel across all available cores on a single machine, or run in parallel across a cluster of machines, all without changing any code. Doing so renders parallelisation a non-issue. In this post I’ll talk through how to do this for a simple bootstrap particle filter, but the same principle applies for a large range of statistical computing algorithms.

Typeclasses and monadic collections

In the previous post I gave a quick introduction to the monad concept, and to monadic collections in particular. Many computational tasks in statistics can be accomplished using a sequence of operations on monadic collections. We would like to write code that is independent of any particular implementation of a monadic collection, so that we can switch to a different implementation without changing the code of our algorithm (for example, switching from a serial to a parallel collection). But in strongly typed languages we need to know at compile time that the collection we use has the methods that we require. Typeclasses provide a nice solution to this problem. I don’t want to get bogged down in a big discussion about Scala typeclasses here, but suffice to say that they describe a family of types conforming to a particular interface in an ad hoc loosely coupled way (they are said to provide ad hoc polymorphism). They are not the same as classes in traditional O-O languages, but they do solve a similar problem to the adaptor design pattern, in a much cleaner way. We can describe a simple typeclass for our monadic collection as follows:

import scala.collection.GenTraversable // available in Scala 2.12 and earlier

trait GenericColl[C[_]] {
  def map[A, B](ca: C[A])(f: A => B): C[B]
  def reduce[A](ca: C[A])(f: (A, A) => A): A
  def flatMap[A, B, D[B] <: GenTraversable[B]](ca: C[A])(f: A => D[B]): C[B]
  def zip[A, B](ca: C[A])(cb: C[B]): C[(A, B)]
  def length[A](ca: C[A]): Int
}

In the typeclass we just list the methods that we expect our generic collection to provide, but do not say anything about how they are implemented. For example, we know that operations such as map and reduce can be executed in parallel, but this is a separate concern. We can now write code that can be used for any collection conforming to the requirements of this typeclass. The full code for this example is provided in the associated github repo for this blog, and includes the obvious syntax for this typeclass, and typeclass instances for the Scala collections Vector and ParVector, that we will exploit later in the example.
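To give a flavour of what an instance looks like, here is a minimal sketch for the standard serial Vector. It is not the code from the repo: the typeclass is renamed Coll and its flatMap is simplified to accept functions returning an Iterable, an assumption made so that the sketch compiles without GenTraversable. The helper mean is also made up, to show generic code written purely against the typeclass.

```scala
object GenericCollSketch {
  // Simplified version of the typeclass (flatMap restricted to Iterable results)
  trait Coll[C[_]] {
    def map[A, B](ca: C[A])(f: A => B): C[B]
    def reduce[A](ca: C[A])(f: (A, A) => A): A
    def flatMap[A, B](ca: C[A])(f: A => Iterable[B]): C[B]
    def zip[A, B](ca: C[A])(cb: C[B]): C[(A, B)]
    def length[A](ca: C[A]): Int
  }

  // Instance for Vector, delegating to the collection's own methods
  implicit val vectorColl: Coll[Vector] = new Coll[Vector] {
    def map[A, B](ca: Vector[A])(f: A => B): Vector[B] = ca.map(f)
    def reduce[A](ca: Vector[A])(f: (A, A) => A): A = ca.reduce(f)
    def flatMap[A, B](ca: Vector[A])(f: A => Iterable[B]): Vector[B] = ca.flatMap(f)
    def zip[A, B](ca: Vector[A])(cb: Vector[B]): Vector[(A, B)] = ca.zip(cb)
    def length[A](ca: Vector[A]): Int = ca.length
  }

  // Generic code that only knows about the typeclass, not about Vector
  def mean[C[_]](xs: C[Double])(implicit ev: Coll[C]): Double =
    ev.reduce(xs)(_ + _) / ev.length(xs)
}
```

After import GenericCollSketch._, a call such as mean(Vector(1.0, 2.0, 3.0)) picks up the Vector instance implicitly. An instance for a parallel collection would have exactly the same shape, which is the whole point of the abstraction.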

SIR step for a bootstrap filter

We can now write some code for a single observation update of a bootstrap particle filter.

def update[S: State, O: Observation, C[_]: GenericColl](
  dataLik: (S, O) => LogLik, stepFun: S => S
)(x: C[S], o: O): (LogLik, C[S]) = {
  val xp = x map (stepFun(_))
  val lw = xp map (dataLik(_, o))
  val max = lw reduce (math.max(_, _))
  val rw = lw map (lwi => math.exp(lwi - max))
  val srw = rw reduce (_ + _)
  val l = rw.length
  val z = rw zip xp
  val rx = z flatMap (p => Vector.fill(Poisson(p._1 * l / srw).draw)(p._2))
  (max + math.log(srw / l), rx)
}

This is a very simple bootstrap filter, using Poisson resampling for simplicity and data locality, but does include use of the log-sum-exp trick to prevent over/underflow of raw weight calculations, and tracks the marginal (log-)likelihood of the observation. With this function we can now pass in a “prior” particle distribution in any collection conforming to our typeclass, together with a propagator function, an observation (log-)likelihood, and an observation, and it will return back a new collection of particles of exactly the same type that was provided for input. Note that all of the operations we require can be accomplished with the standard monadic collection operations declared in our typeclass.
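The log-sum-exp trick used in update can be isolated as follows. This is a small stand-alone sketch (the name logMeanExp is made up) which computes the log of the average of the raw weights, given the log-weights, by first factoring out the largest log-weight. It mirrors the max, rw and srw steps in the function above.

```scala
object LogMeanExpSketch {
  // log( (1/n) * sum_i exp(lw_i) ), computed stably
  def logMeanExp(lw: Vector[Double]): Double = {
    val mx = lw.max                                 // largest log-weight
    val srw = lw.map(lwi => math.exp(lwi - mx)).sum // sum of rescaled weights, each in [0, 1]
    mx + math.log(srw / lw.length)
  }

  // The naive version, which underflows for very negative log-weights
  def naive(lw: Vector[Double]): Double =
    math.log(lw.map(math.exp).sum / lw.length)
}
```

For log-weights of around -1000, which are not unusual after conditioning on a moderately long time series, the naive version returns negative infinity because every exp underflows to zero, whereas the rescaled version returns a sensible answer.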

Filtering as a functional fold

Once we have a function for executing one step of a particle filter, we can produce a function for particle filtering as a functional fold over a sequence of observations:

def pFilter[S: State, O: Observation, C[_]: GenericColl, D[O] <: GenTraversable[O]](
  x0: C[S], data: D[O], dataLik: (S, O) => LogLik, stepFun: S => S
): (LogLik, C[S]) = {
  val updater = update[S, O, C](dataLik, stepFun) _
  data.foldLeft((0.0, x0))((prev, o) => {
    val next = updater(prev._2, o)
    (prev._1 + next._1, next._2)
  })
}

Folding data structures is a fundamental concept in functional programming, and is exactly what is required for any kind of filtering problem. Note that Brian Beckman has recently written a series of articles on Kalman filtering as a functional fold.
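To illustrate the filtering-as-a-fold pattern in a case where everything is available in closed form, here is a minimal sketch of a one-dimensional Kalman filter for a Gaussian random walk observed with Gaussian noise. This is a toy example written for this purpose, not code from the articles mentioned above, and the names Belief, step and filter are illustrative.

```scala
object KalmanFoldSketch {
  // Current belief about the hidden state: posterior mean m and variance c
  final case class Belief(m: Double, c: Double)

  // One predict-update step for x_t = x_{t-1} + N(0, w), y_t = x_t + N(0, v)
  def step(w: Double, v: Double)(b: Belief, y: Double): Belief = {
    val r = b.c + w     // predictive variance of the state
    val k = r / (r + v) // Kalman gain
    Belief(b.m + k * (y - b.m), (1 - k) * r)
  }

  // Filtering is just a left fold of the step function over the observations
  def filter(b0: Belief, w: Double, v: Double, ys: Seq[Double]): Belief =
    ys.foldLeft(b0)((b, y) => step(w, v)(b, y))
}
```

The structure is the same as that of pFilter: an initial state, a function that combines the current state with one observation, and a fold over the data.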

Marginal likelihoods and parameter estimation

So far we haven’t said anything about parameters or parameter estimation, but this is appropriate, since parametrisation is a separate concern from filtering. However, once we have a function for particle filtering, we can produce a function concerned with evaluating marginal likelihoods trivially:

def pfMll[S: State, P: Parameter, O: Observation, 
            C[_]: GenericColl, D[O] <: GenTraversable[O]](
  simX0: P => C[S], stepFun: P => S => S, 
  dataLik: P => (S, O) => LogLik, data: D[O]
): (P => LogLik) = (th: P) => 
       pFilter(simX0(th), data, dataLik(th), stepFun(th))._1

Note that this higher-order function does not return a log-likelihood directly, but instead returns a function which will accept a parameter as input and return a (log-)likelihood as output. This can then be used for parameter estimation purposes, perhaps being used in a PMMH pMCMC algorithm, or something else. Again, this is a separate concern.

Example

Here I’ll just give a completely trivial toy example, purely to show how the functions work. For avoidance of doubt, I know that there are many better/simpler/easier ways to tackle this problem! Here we will just look at inferring the auto-regression parameter of a linear Gaussian AR(1)-plus-noise model using the functions we have developed.

First we can simulate some synthetic data from this model, using a value of 0.8 for the auto-regression parameter:

val inNoise = Gaussian(0.0, 1.0).sample(99)
val state = DenseVector(inNoise.scanLeft(0.0)((s, i) => 0.8 * s + i).toArray)
val noise = DenseVector(Gaussian(0.0, 2.0).sample(100).toArray)
val data = (state + noise).toArray.toList

Now assuming that we don’t know the auto-regression parameter, we can construct a function to evaluate the likelihood of different parameter values as follows:

val mll = pfMll(
  (th: Double) => Gaussian(0.0, 10.0).sample(10000).toVector.par,
  (th: Double) => (s: Double) => Gaussian(th * s, 1.0).draw,
  (th: Double) => (s: Double, o: Double) => Gaussian(s, 2.0).logPdf(o),
  data
)

Note that the 4 characters “.par” at the end of line 2 are the only difference between running this code serially or in parallel! Now we can run this code by calling the returned function with different values. So, hopefully mll(0.8) will return a larger log-likelihood than (say) mll(0.6) or mll(0.9). The example code in the github repo plots the results of calling mll() for a range of values (note that if that was the genuine use-case, then it would be much better to parallelise over the parameter range than over the particle filter, since this gives better parallelisation granularity, but many other examples require parallelisation of the particle filter itself). In this particular example, both the forward model and the likelihood are very cheap operations, so there is little to be gained from parallelisation. Nevertheless, I still get a speedup of more than a factor of two using the parallel version on my laptop.

Conclusion

In this post we have shown how typeclasses can be used in Scala to write code that is parallelisation-agnostic. Code written in this way can be run on one or many cores as desired. We’ve illustrated the concept with a scalable particle filter, but nothing about the approach is specific to that application. It would be easy to build up a library of statistical routines this way, all of which can effectively exploit available parallel hardware. Further, although we haven’t demonstrated it here, it is trivial to extend this idea to allow code to be distributed over a cluster of parallel machines if necessary. For example, if an Apache Spark cluster is available, it is easy to make a Spark RDD instance for our generic collection typeclass, that will then allow us to run our (unmodified) particle filter code over a Spark cluster. This emphasises the fact that Spark can be useful for distributing computation as well as just processing “big data”. I’ll say more about Spark in subsequent posts.