MCMC as a Stream

Introduction

This weekend I’ve been preparing some material for my upcoming Scala for statistical computing short course. As part of the course, I thought it would be useful to walk through how to think about and structure MCMC codes, and in particular, how to think about MCMC algorithms as infinite streams of state. This material is reasonably stand-alone, so it seems suitable for a blog post. Complete runnable code for the examples in this post is available from my blog repo.

A simple MH sampler

For this post I will just consider a trivial toy Metropolis algorithm using a Uniform random walk proposal to target a standard normal distribution. I’ve considered this problem before on my blog, so if you aren’t very familiar with Metropolis-Hastings algorithms, you might want to quickly review my post on Metropolis-Hastings MCMC algorithms in R before continuing. At the end of that post, I gave the following R code for the Metropolis sampler:

metrop3<-function(n=1000,eps=0.5) 
{
        vec=vector("numeric", n)
        x=0
        oldll=dnorm(x,log=TRUE)
        vec[1]=x
        for (i in 2:n) {
                can=x+runif(1,-eps,eps)
                loglik=dnorm(can,log=TRUE)
                loga=loglik-oldll
                if (log(runif(1)) < loga) { 
                        x=can
                        oldll=loglik
                        }
                vec[i]=x
        }
        vec
}
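
For reference, the accept/reject step in the code above can be written out explicitly. This is just the standard Metropolis rule for a symmetric proposal, specialised to the standard normal target used here; the symbols $x^\star$ (the candidate, can in the code), $\pi$ (the target density) and $u$ (the uniform draw) are notation introduced only for this note.

```latex
\log A = \log \pi(x^\star) - \log \pi(x),
\qquad \text{accept } x^\star \text{ if } \log u < \log A, \quad u \sim U(0,1).

% For the standard normal target, \log \pi(x) = -\tfrac{1}{2}x^2 - \tfrac{1}{2}\log(2\pi), so
\log A = \tfrac{1}{2}\left(x^2 - {x^\star}^2\right).
```

There is no need to take the minimum of $A$ and 1, since $\log u \le 0$ means that any candidate with $\log A \ge 0$ is accepted anyway.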

I will begin this post with a fairly direct translation of this algorithm into Scala:

import breeze.linalg.DenseVector
import breeze.stats.distributions._

def metrop1(n: Int = 1000, eps: Double = 0.5): DenseVector[Double] = {
    val vec = DenseVector.fill(n)(0.0)
    var x = 0.0
    var oldll = Gaussian(0.0, 1.0).logPdf(x)
    vec(0) = x
    (1 until n).foreach { i =>
      val can = x + Uniform(-eps, eps).draw
      val loglik = Gaussian(0.0, 1.0).logPdf(can)
      val loga = loglik - oldll
      if (math.log(Uniform(0.0, 1.0).draw) < loga) {
        x = can
        oldll = loglik
      }
      vec(i) = x
    }
    vec
}

This code works, and is reasonably fast and efficient, but there are several issues with it from a functional programmer’s perspective. One issue is that we have committed to storing all MCMC output in RAM in a DenseVector. This probably isn’t an issue here, but for some big problems we might prefer to not store the full set of states, but to just print the states to (say) the console, for possible re-direction to a file. It is easy enough to modify the code to do this:

def metrop2(n: Int = 1000, eps: Double = 0.5): Unit = {
    var x = 0.0
    var oldll = Gaussian(0.0, 1.0).logPdf(x)
    (1 to n).foreach { i =>
      val can = x + Uniform(-eps, eps).draw
      val loglik = Gaussian(0.0, 1.0).logPdf(can)
      val loga = loglik - oldll
      if (math.log(Uniform(0.0, 1.0).draw) < loga) {
        x = can
        oldll = loglik
      }
      println(x)
    }
}

But now we have two versions of the algorithm: one for storing results locally, and one for streaming results to the console. This is clearly unsatisfactory, but we shall return to this issue shortly. Another issue that will jump out at functional programmers is the reliance on mutable variables for storing the state and old likelihood. Let’s fix that now by re-writing the algorithm as a tail-recursion.

import scala.annotation.tailrec

@tailrec
def metrop3(n: Int = 1000, eps: Double = 0.5, x: Double = 0.0, oldll: Double = Double.MinValue): Unit = {
    if (n > 0) {
      println(x)
      val can = x + Uniform(-eps, eps).draw
      val loglik = Gaussian(0.0, 1.0).logPdf(can)
      val loga = loglik - oldll
      if (math.log(Uniform(0.0, 1.0).draw) < loga)
        metrop3(n - 1, eps, can, loglik)
      else
        metrop3(n - 1, eps, x, oldll)
    }
}

This has eliminated the vars, and is just as fast and efficient as the previous version of the code. Note that the @tailrec annotation is optional – it just signals to the compiler that we want it to throw an error if for some reason it cannot eliminate the tail call. However, this is for the print-to-console version of the code. What if we actually want to keep the iterations in RAM for subsequent analysis? We can keep the values in an accumulator, as follows.

@tailrec
def metrop4(n: Int = 1000, eps: Double = 0.5, x: Double = 0.0, oldll: Double = Double.MinValue, acc: List[Double] = Nil): DenseVector[Double] = {
    if (n == 0)
      DenseVector(acc.reverse.toArray)
    else {
      val can = x + Uniform(-eps, eps).draw
      val loglik = Gaussian(0.0, 1.0).logPdf(can)
      val loga = loglik - oldll
      if (math.log(Uniform(0.0, 1.0).draw) < loga)
        metrop4(n - 1, eps, can, loglik, can :: acc)
      else
        metrop4(n - 1, eps, x, oldll, x :: acc)
    }
}

Factoring out the updating logic

This is all fine, but we haven’t yet addressed the issue of having different versions of the code depending on what we want to do with the output. The problem is that we have tied up the logic of advancing the Markov chain with what to do with the output. What we need to do is separate out the code for advancing the state. We can do this by defining a new function.

def newState(x: Double, oldll: Double, eps: Double): (Double, Double) = {
    val can = x + Uniform(-eps, eps).draw
    val loglik = Gaussian(0.0, 1.0).logPdf(can)
    val loga = loglik - oldll
    if (math.log(Uniform(0.0, 1.0).draw) < loga) (can, loglik) else (x, oldll)
}

This function takes as input a current state and associated log likelihood and returns a new state and log likelihood following the execution of one step of an MH algorithm. This separates the concern of state updating from the rest of the code. So now if we want to write code that prints the state, we can write it as

  @tailrec
  def metrop5(n: Int = 1000, eps: Double = 0.5, x: Double = 0.0, oldll: Double = Double.MinValue): Unit = {
    if (n > 0) {
      println(x)
      val ns = newState(x, oldll, eps)
      metrop5(n - 1, eps, ns._1, ns._2)
    }
  }

and if we want to accumulate the set of states visited, we can write that as

  @tailrec
  def metrop6(n: Int = 1000, eps: Double = 0.5, x: Double = 0.0, oldll: Double = Double.MinValue, acc: List[Double] = Nil): DenseVector[Double] = {
    if (n == 0) DenseVector(acc.reverse.toArray) else {
      val ns = newState(x, oldll, eps)
      metrop6(n - 1, eps, ns._1, ns._2, ns._1 :: acc)
    }
  }

Both of these functions call newState to do the real work, and concentrate on what to do with the sequence of states. However, both of these functions repeat the logic of how to iterate over the sequence of states.

MCMC as a stream

Ideally we would like to abstract out the details of how to do state iteration from the code as well. Most functional languages have some concept of a Stream, which represents a (potentially infinite) sequence of states. The Stream can embody the logic of how to perform state iteration, allowing us to abstract that away from our code, as well.

To do this, we will restructure our code slightly so that it more clearly maps old state to new state.

def nextState(eps: Double)(state: (Double, Double)): (Double, Double) = {
    val x = state._1
    val oldll = state._2
    val can = x + Uniform(-eps, eps).draw
    val loglik = Gaussian(0.0, 1.0).logPdf(can)
    val loga = loglik - oldll
    if (math.log(Uniform(0.0, 1.0).draw) < loga) (can, loglik) else (x, oldll)
}

The "real" state of the chain is just x, but if we want to avoid recalculation of the old likelihood, then we need to make this part of the chain’s state. We can use this nextState function in order to construct a Stream.

  def metrop7(eps: Double = 0.5, x: Double = 0.0, oldll: Double = Double.MinValue): Stream[Double] =
    Stream.iterate((x, oldll))(nextState(eps)) map (_._1)

The result of calling this is an infinite stream of states. Obviously it isn’t computed – that would require infinite computation, but it captures the logic of iteration and computation in a Stream, that can be thought of as a lazy List. We can get values out by converting the Stream to a regular collection, being careful to truncate the Stream to one of finite length beforehand! For example, metrop7().drop(1000).take(10000).toArray will do a burn-in of 1,000 iterations followed by a main monitoring run of length 10,000, capturing the results in an Array. Note that metrop7().drop(1000).take(10000) is itself a Stream: drop has to step through the burn-in states in order to discard them, but the monitored states after the first are not computed until the toArray is encountered. Conversely, if printing to console is required, just replace the .toArray with .foreach(println).

The above stream-based approach to MCMC iteration is clean and elegant, and deals nicely with issues like burn-in and thinning (which can be handled similarly). This is how I typically write MCMC codes these days. However, functional programming purists would still have issues with this approach, as it isn’t quite pure functional: nextState has a side-effect, which is to mutate the state of the underlying pseudo-random number generator. If the code were pure, calling nextState with the same inputs would always give the same result. Clearly this isn’t the case here, as we have specifically designed the function to be stochastic, returning a randomly sampled value from the desired probability distribution. So nextState represents a function for randomly sampling from a conditional probability distribution.
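
To make the remark about thinning concrete, here is a minimal sketch of how burn-in and thinning might be combined on a Stream. It deliberately avoids Breeze so that it can be pasted into a plain Scala REPL: thin, nextStateStd and logPdfStdNormal are hypothetical helpers written for this illustration (they are not part of any library), and scala.util.Random stands in for the Breeze distributions. On Scala 2.13 or later, Stream is deprecated in favour of LazyList, so expect a deprecation warning there.

```scala
import scala.util.Random

// Log density of N(0,1), written out so that no Breeze import is needed
def logPdfStdNormal(x: Double): Double =
  -0.5 * x * x - 0.5 * math.log(2.0 * math.Pi)

// Same Metropolis update as nextState, but drawing from a scala.util.Random
def nextStateStd(rng: Random, eps: Double)(state: (Double, Double)): (Double, Double) = {
  val (x, oldll) = state
  val can = x + eps * (2.0 * rng.nextDouble() - 1.0) // Uniform(-eps, eps) innovation
  val loglik = logPdfStdNormal(can)
  if (math.log(rng.nextDouble()) < loglik - oldll) (can, loglik) else (x, oldll)
}

// Hypothetical helper: keep every th-th element of a (possibly infinite) Stream
def thin[T](s: Stream[T], th: Int): Stream[T] = {
  val ss = s.drop(th - 1)
  if (ss.isEmpty) Stream.empty else ss.head #:: thin(ss.tail, th)
}

val rng = new Random(42)
val chain: Stream[Double] =
  Stream.iterate((0.0, logPdfStdNormal(0.0)))(nextStateStd(rng, 0.5)).map(_._1)

// Burn-in of 1,000, then keep every 10th state until 1,000 values are collected
val out: Array[Double] = thin(chain.drop(1000), 10).take(1000).toArray
```

Because thin returns another Stream, it composes with drop, take, toArray and foreach(println) in exactly the same way as the unthinned chain.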

A pure functional approach

Now, ultimately all code has side-effects, or there would be no point in running it! But in functional programming the desire is to make as much of the code as possible pure, and to push side-effects to the very edges of the code. So it’s fine to have side-effects in your main method, but not buried deep in your code. Here the side-effect is at the very heart of the code, which is why it is potentially an issue.

To keep things as simple as possible, at this point we will stop worrying about carrying forward the old likelihood, and hard-code a value of eps. Generalisation is straightforward. We can make our code pure by instead defining a function which represents the conditional probability distribution itself. For this we use a probability monad, which in Breeze is called Rand. We can couple together such functions using monadic binds (flatMap in Scala), expressed most neatly using for-comprehensions. So we can write our transition kernel as

import breeze.stats.distributions._ // provides Rand, Uniform, Gaussian and MarkovChain

def kernel(x: Double): Rand[Double] = for {
    innov <- Uniform(-0.5, 0.5)
    can = x + innov
    oldll = Gaussian(0.0, 1.0).logPdf(x)
    loglik = Gaussian(0.0, 1.0).logPdf(can)
    loga = loglik - oldll
    u <- Uniform(0.0, 1.0)
} yield if (math.log(u) < loga) can else x

This is now pure – the same input x will always return the same probability distribution – the conditional distribution of the next state given the current state. We can draw random samples from this distribution if we must, but it’s probably better to work as long as possible with pure functions. So next we need to encapsulate the iteration logic. Breeze has a MarkovChain object which can take kernels of this form and return a stochastic Process object representing the iteration logic, as follows.

MarkovChain(0.0)(kernel).
  steps.
  drop(1000).
  take(10000).
  foreach(println)

The steps method contains the logic of how to advance the state of the chain. But again note that no computation actually takes place until the foreach method is encountered – this is when the sampling occurs and the side-effects happen.
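
For readers who have not met a probability monad before, the following toy version may help to show why the kernel can be pure even though sampling is not. It is only a sketch under stated assumptions: Dist, uniform, kernelSketch and runChain are hypothetical names invented for this illustration, the code uses only the Scala standard library, and it is not how Breeze implements Rand.

```scala
import scala.util.Random

// Hypothetical stand-in for a probability monad: a description of how to
// sample a value once a source of randomness is supplied. NOT Breeze's Rand.
final case class Dist[A](sample: Random => A) {
  def map[B](f: A => B): Dist[B] = Dist(rng => f(sample(rng)))
  def flatMap[B](f: A => Dist[B]): Dist[B] = Dist(rng => f(sample(rng)).sample(rng))
}

def uniform(lo: Double, hi: Double): Dist[Double] =
  Dist(rng => lo + (hi - lo) * rng.nextDouble())

def logPdfStdNormal(x: Double): Double =
  -0.5 * x * x - 0.5 * math.log(2.0 * math.Pi)

// Building the kernel draws no random numbers: it only composes descriptions
def kernelSketch(x: Double): Dist[Double] = for {
  innov <- uniform(-0.5, 0.5)
  u <- uniform(0.0, 1.0)
} yield {
  val can = x + innov
  val loga = logPdfStdNormal(can) - logPdfStdNormal(x)
  if (math.log(u) < loga) can else x
}

// Side-effects happen only here, when a Random is finally supplied
def runChain(init: Double, n: Int, rng: Random): List[Double] =
  Iterator.iterate(init)(x => kernelSketch(x).sample(rng)).take(n).toList
```

In this sketch, two runs of runChain started from equally seeded generators give identical output, which is one practical benefit of keeping the randomness at the edge of the program.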

Metropolis-Hastings is a common use-case for Markov chains, so Breeze actually has a helper method built-in that will construct a MH sampler directly from an initial state, a proposal kernel, and a (log) target.

MarkovChain.
  metropolisHastings(0.0, (x: Double) => Uniform(x - 0.5, x + 0.5))(
    x => Gaussian(0.0, 1.0).logPdf(x)).
  steps.
  drop(1000).
  take(10000).
  toArray

Note that if you are using the MH functionality in Breeze, it is important to make sure that you are using version 0.13 (or later), as I fixed a few issues with the MH code shortly prior to the 0.13 release.

Summary

Viewing MCMC algorithms as infinite streams of state is useful for writing elegant, generic, flexible code. Streams occur everywhere in programming, and so there are lots of libraries for working with them. In this post I used the simple Stream from the Scala standard library, but there are much more powerful and flexible stream libraries for Scala, including fs2 and Akka-streams. But whatever libraries you are using, the fundamental concepts are the same. The most straightforward approach to implementation is to define impure stochastic streams to consume. However, a pure functional approach is also possible, and the Breeze library defines some useful functions to facilitate this approach. I’m still a little bit ambivalent about whether the pure approach is worth the additional cognitive overhead, but it’s certainly very interesting and worth playing with and thinking about the pros and cons.

Complete runnable code for the examples in this post is available from my blog repo.

Brief introduction to Scala and Breeze for statistical computing

Introduction

In the previous post I outlined why I think Scala is a good language for statistical computing and data science. In this post I want to give a quick taste of Scala and the Breeze numerical library to whet the appetite of the uninitiated. This post certainly won’t provide enough material to get started using Scala in anger – but I’ll try and provide a few pointers along the way. It also won’t be very interesting to anyone who knows Scala – I’m not introducing any of the very cool Scala stuff here – I think that some of the most powerful and interesting Scala language features can be a bit frightening for new users.

To reproduce the examples, you need to install Scala and Breeze. This isn’t very tricky, but I don’t want to get bogged down with a detailed walk-through here – I want to concentrate on the Scala language and Breeze library. You just need to install a recent version of Java, then Scala, and then Breeze. You might also want SBT and/or the ScalaIDE, though neither of these is necessary. Then you need to run the Scala REPL with the Breeze library in the classpath. There are several ways one can do this. The most obvious is to just run scala with the path to Breeze manually specified (or specified in an environment variable). Alternatively, you could run a console from an sbt session with a Breeze dependency (which is what I actually did for this post), or you could use a Scala Worksheet from inside a ScalaIDE project with a Breeze dependency.

A Scala REPL session

A first glimpse of Scala

We’ll start with a few simple Scala concepts that are not dependent on Breeze. For further information, see the Scala documentation.

Welcome to Scala version 2.10.3 (OpenJDK 64-Bit Server VM, Java 1.7.0_25).
Type in expressions to have them evaluated.
Type :help for more information.

scala> val a = 5
a: Int = 5

scala> a
res0: Int = 5

So far, so good. Using the Scala REPL is much like using the Python or R command line, so will be very familiar to anyone used to these or similar languages. The first thing to note is that labels need to be declared on first use. We have declared a to be a val. These are immutable values, which cannot simply be re-assigned, as the following code illustrates.

scala> a = 6
<console>:8: error: reassignment to val
       a = 6
         ^
scala> a
res1: Int = 5

Immutability seems to baffle people unfamiliar with functional programming. But fear not, as Scala allows declaration of mutable variables as well:

scala> var b = 7
b: Int = 7

scala> b
res2: Int = 7

scala> b = 8
b: Int = 8

scala> b
res3: Int = 8

The Zen of functional programming is to realise that immutability is generally a good thing, but that really isn’t the point of this post. Scala has excellent support for both mutable and immutable collections as part of the standard library. See the API docs for more details. For example, it has immutable lists.

scala> val c = List(3,4,5,6)
c: List[Int] = List(3, 4, 5, 6)

scala> c(1)
res4: Int = 4

scala> c.sum
res5: Int = 18

scala> c.length
res6: Int = 4

scala> c.product
res7: Int = 360

Again, this should be pretty familiar stuff for anyone familiar with Python. Note that the sum and product methods are special cases of reduce operations, which are well supported in Scala. For example, we could compute the sum reduction using

scala> c.foldLeft(0)((x,y) => x+y)
res8: Int = 18

or the slightly more condensed form given below, and similarly for the product reduction.

scala> c.foldLeft(0)(_+_)
res9: Int = 18

scala> c.foldLeft(1)(_*_)
res10: Int = 360
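
The accumulator in a fold need not have the same type as the list elements, which is what makes folds useful for statistical summaries. As a small illustrative sketch (the variable names are made up for this example), a tuple accumulator gives the count, sum and sum of squares in a single pass, from which the sample mean and variance follow.

```scala
val c = List(3, 4, 5, 6)

// Accumulate (count, sum, sum of squares) in a single pass over the list
val (n, s, ss) = c.foldLeft((0, 0.0, 0.0)) {
  case ((k, sum, sumSq), x) => (k + 1, sum + x, sumSq + x * x)
}

val mean = s / n                                // 18.0 / 4
val variance = (ss - n * mean * mean) / (n - 1) // sample variance, divisor n - 1
```

This one-pass formula is fine for a toy list like this one, though it can lose precision when the mean is large relative to the spread of the data.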

Scala also has a nice immutable Vector class, which offers a range of constant time operations (but note that this has nothing to do with the mutable Vector class that is part of the Breeze library).

scala> val d = Vector(2,3,4,5,6,7,8,9)
d: scala.collection.immutable.Vector[Int] = Vector(2, 3, 4, 5, 6, 7, 8, 9)

scala> d
res11: scala.collection.immutable.Vector[Int] = Vector(2, 3, 4, 5, 6, 7, 8, 9)

scala> d.slice(3,6)
res12: scala.collection.immutable.Vector[Int] = Vector(5, 6, 7)

scala> val e = d.updated(3,0)
e: scala.collection.immutable.Vector[Int] = Vector(2, 3, 4, 0, 6, 7, 8, 9)

scala> d
res13: scala.collection.immutable.Vector[Int] = Vector(2, 3, 4, 5, 6, 7, 8, 9)

scala> e
res14: scala.collection.immutable.Vector[Int] = Vector(2, 3, 4, 0, 6, 7, 8, 9)

Note that when e is created as an updated version of d the whole of d is not copied – only the parts that have been updated. And we don’t have to worry that aspects of d and e point to the same information in memory, as they are both immutable… As should be clear by now, Scala has excellent support for functional programming techniques. In addition to the reduce operations mentioned already, maps and filters are also well covered.

scala> val f=(1 to 10).toList
f: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

scala> f
res15: List[Int] = List(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)

scala> f.map(x => x*x)
res16: List[Int] = List(1, 4, 9, 16, 25, 36, 49, 64, 81, 100)

scala> f map {x => x*x}
res17: List[Int] = List(1, 4, 9, 16, 25, 36, 49, 64, 81, 100)

scala> f filter {_ > 4}
res18: List[Int] = List(5, 6, 7, 8, 9, 10)

Note how Scala allows methods with a single argument to be written as an infix operator, making for more readable code.

A first look at Breeze

The next part of the session requires the Breeze library – see the Breeze quickstart guide for further details. We begin by taking a quick look at everyone’s favourite topic of non-uniform random number generation. Let’s start by generating a couple of draws from a Poisson distribution with mean 3.

scala> import breeze.stats.distributions._
import breeze.stats.distributions._

scala> val poi = Poisson(3.0)
poi: breeze.stats.distributions.Poisson = Poisson(3.0)

scala> poi.draw
res19: Int = 2

scala> poi.draw
res20: Int = 3

If more than a single draw is required, an iid sample can be obtained.

scala> val x = poi.sample(10)
x: IndexedSeq[Int] = Vector(2, 3, 3, 4, 2, 2, 1, 2, 4, 2)

scala> x
res21: IndexedSeq[Int] = Vector(2, 3, 3, 4, 2, 2, 1, 2, 4, 2)

scala> x.sum
res22: Int = 25

scala> x.length
res23: Int = 10

scala> x.sum.toDouble/x.length
res24: Double = 2.5

Note that this Vector is the immutable Vector from the Scala standard library, not a mutable Breeze Vector. The probability mass function (PMF) of the Poisson distribution is also available.

scala> poi.probabilityOf(2)
res25: Double = 0.22404180765538775

scala> x map {x => poi.probabilityOf(x)}
res26: IndexedSeq[Double] = Vector(0.22404180765538775, 0.22404180765538775, 0.22404180765538775, 0.16803135574154085, 0.22404180765538775, 0.22404180765538775, 0.14936120510359185, 0.22404180765538775, 0.16803135574154085, 0.22404180765538775)

scala> x map {poi.probabilityOf(_)}
res27: IndexedSeq[Double] = Vector(0.22404180765538775, 0.22404180765538775, 0.22404180765538775, 0.16803135574154085, 0.22404180765538775, 0.22404180765538775, 0.14936120510359185, 0.22404180765538775, 0.16803135574154085, 0.22404180765538775)

Obviously, Gaussian variables (and Gamma, and several others) are supported in a similar way.

scala> val gau=Gaussian(0.0,1.0)
gau: breeze.stats.distributions.Gaussian = Gaussian(0.0, 1.0)

scala> gau.draw
res28: Double = 1.606121255846881

scala> gau.draw
res29: Double = -0.1747896055492152

scala> val y=gau.sample(20)
y: IndexedSeq[Double] = Vector(-1.3758577012869702, -1.2148314970824652, -0.022501190144116855, 0.3244006323566883, 0.35978577573558407, 0.9651857500320781, -0.40834034207848985, 0.11583348205331555, -0.8797699986810634, -0.33609738668214695, 0.7043252811790879, -1.2045594639823656, 0.19442688045065826, -0.31442160076087067, 0.06313451540562891, -1.5304745838587115, -1.2372764884467027, 0.5875490994217284, -0.9385520597707431, -0.6647903243363228)

scala> y
res30: IndexedSeq[Double] = Vector(-1.3758577012869702, -1.2148314970824652, -0.022501190144116855, 0.3244006323566883, 0.35978577573558407, 0.9651857500320781, -0.40834034207848985, 0.11583348205331555, -0.8797699986810634, -0.33609738668214695, 0.7043252811790879, -1.2045594639823656, 0.19442688045065826, -0.31442160076087067, 0.06313451540562891, -1.5304745838587115, -1.2372764884467027, 0.5875490994217284, -0.9385520597707431, -0.6647903243363228)

scala> y.sum/y.length
res31: Double = -0.34064156102380994

scala> y map {gau.logPdf(_)}
res32: IndexedSeq[Double] = Vector(-1.8654307403000054, -1.6568463163564844, -0.9191916849836235, -0.9715564183413823, -0.9836614354155007, -1.3847302992371653, -1.0023094506890617, -0.9256472309869705, -1.3059361584943119, -0.975419259871957, -1.1669755840586733, -1.6444202843394145, -0.93783943912556, -0.9683690047171869, -0.9209315167224245, -2.090114759123421, -1.6843650876361744, -1.0915455053203147, -1.359378517654625, -1.1399116208702693)

scala> Gamma(2.0,3.0).sample(5)
res33: IndexedSeq[Double] = Vector(2.38436441278546, 2.125017198373521, 2.333118708811143, 5.880076392566909, 2.0901427084667503)

This is all good stuff for those of us who like to do Markov chain Monte Carlo. There are not masses of statistical data analysis routines built into Breeze, but a few basic tools are provided, including some basic summary statistics.

scala> import breeze.stats.DescriptiveStats._
import breeze.stats.DescriptiveStats._

scala> mean(y)
res34: Double = -0.34064156102380994

scala> variance(y)
res35: Double = 0.574257149387757

scala> meanAndVariance(y)
res36: (Double, Double) = (-0.34064156102380994,0.574257149387757)

Support for linear algebra is an important part of any scientific library. Here the Breeze developers have made the wise decision to provide a nice Scala interface to netlib-java. This in turn calls out to any native optimised BLAS or LAPACK libraries installed on the system, but will fall back to Java code if no optimised libraries are available. This means that linear algebra code using Scala and Breeze should run as fast as code written in any other language, including C, C++ and Fortran, provided that optimised libraries are installed on the system. For further details see the Breeze linear algebra guide. Let’s start by creating and messing with a dense vector.

scala> import breeze.linalg._
import breeze.linalg._

scala> val v=DenseVector(y.toArray)
v: breeze.linalg.DenseVector[Double] = DenseVector(-1.3758577012869702, -1.2148314970824652, -0.022501190144116855, 0.3244006323566883, 0.35978577573558407, 0.9651857500320781, -0.40834034207848985, 0.11583348205331555, -0.8797699986810634, -0.33609738668214695, 0.7043252811790879, -1.2045594639823656, 0.19442688045065826, -0.31442160076087067, 0.06313451540562891, -1.5304745838587115, -1.2372764884467027, 0.5875490994217284, -0.9385520597707431, -0.6647903243363228)

scala> v(1) = 0

scala> v
res38: breeze.linalg.DenseVector[Double] = DenseVector(-1.3758577012869702, 0.0, -0.022501190144116855, 0.3244006323566883, 0.35978577573558407, 0.9651857500320781, -0.40834034207848985, 0.11583348205331555, -0.8797699986810634, -0.33609738668214695, 0.7043252811790879, -1.2045594639823656, 0.19442688045065826, -0.31442160076087067, 0.06313451540562891, -1.5304745838587115, -1.2372764884467027, 0.5875490994217284, -0.9385520597707431, -0.6647903243363228)

scala> v(1 to 3) := 1.0
res39: breeze.linalg.DenseVector[Double] = DenseVector(1.0, 1.0, 1.0)

scala> v
res40: breeze.linalg.DenseVector[Double] = DenseVector(-1.3758577012869702, 1.0, 1.0, 1.0, 0.35978577573558407, 0.9651857500320781, -0.40834034207848985, 0.11583348205331555, -0.8797699986810634, -0.33609738668214695, 0.7043252811790879, -1.2045594639823656, 0.19442688045065826, -0.31442160076087067, 0.06313451540562891, -1.5304745838587115, -1.2372764884467027, 0.5875490994217284, -0.9385520597707431, -0.6647903243363228)

scala> v(1 to 3) := DenseVector(1.0,1.5,2.0)
res41: breeze.linalg.DenseVector[Double] = DenseVector(1.0, 1.5, 2.0)

scala> v
res42: breeze.linalg.DenseVector[Double] = DenseVector(-1.3758577012869702, 1.0, 1.5, 2.0, 0.35978577573558407, 0.9651857500320781, -0.40834034207848985, 0.11583348205331555, -0.8797699986810634, -0.33609738668214695, 0.7043252811790879, -1.2045594639823656, 0.19442688045065826, -0.31442160076087067, 0.06313451540562891, -1.5304745838587115, -1.2372764884467027, 0.5875490994217284, -0.9385520597707431, -0.6647903243363228)

scala> v :> 0.0
res43: breeze.linalg.BitVector = BitVector(1, 2, 3, 4, 5, 7, 10, 12, 14, 17)

scala> (v :> 0.0).toArray
res44: Array[Boolean] = Array(false, true, true, true, true, true, false, true, false, false, true, false, true, false, true, false, false, true, false, false)

Next let’s create and mess around with some dense matrices.

scala> val m = new DenseMatrix(5,4,linspace(1.0,20.0,20).toArray)
m: breeze.linalg.DenseMatrix[Double] = 
1.0  6.0   11.0  16.0  
2.0  7.0   12.0  17.0  
3.0  8.0   13.0  18.0  
4.0  9.0   14.0  19.0  
5.0  10.0  15.0  20.0  

scala> m
res45: breeze.linalg.DenseMatrix[Double] = 
1.0  6.0   11.0  16.0  
2.0  7.0   12.0  17.0  
3.0  8.0   13.0  18.0  
4.0  9.0   14.0  19.0  
5.0  10.0  15.0  20.0  

scala> m.rows
res46: Int = 5

scala> m.cols
res47: Int = 4

scala> m(::,1)
res48: breeze.linalg.DenseVector[Double] = DenseVector(6.0, 7.0, 8.0, 9.0, 10.0)

scala> m(1,::)
res49: breeze.linalg.DenseMatrix[Double] = 2.0  7.0  12.0  17.0  

scala> m(1,::) := linspace(1.0,2.0,4)
res50: breeze.linalg.DenseMatrix[Double] = 1.0  1.3333333333333333  1.6666666666666665  2.0  

scala> m
res51: breeze.linalg.DenseMatrix[Double] = 
1.0  6.0                 11.0                16.0  
1.0  1.3333333333333333  1.6666666666666665  2.0   
3.0  8.0                 13.0                18.0  
4.0  9.0                 14.0                19.0  
5.0  10.0                15.0                20.0  

scala> 

scala> val n = m.t
n: breeze.linalg.DenseMatrix[Double] = 
1.0   1.0                 3.0   4.0   5.0   
6.0   1.3333333333333333  8.0   9.0   10.0  
11.0  1.6666666666666665  13.0  14.0  15.0  
16.0  2.0                 18.0  19.0  20.0  

scala> n
res52: breeze.linalg.DenseMatrix[Double] = 
1.0   1.0                 3.0   4.0   5.0   
6.0   1.3333333333333333  8.0   9.0   10.0  
11.0  1.6666666666666665  13.0  14.0  15.0  
16.0  2.0                 18.0  19.0  20.0  

scala> val o = m*n
o: breeze.linalg.DenseMatrix[Double] = 
414.0              59.33333333333333  482.0              516.0              550.0              
59.33333333333333  9.555555555555555  71.33333333333333  77.33333333333333  83.33333333333333  
482.0              71.33333333333333  566.0              608.0              650.0              
516.0              77.33333333333333  608.0              654.0              700.0              
550.0              83.33333333333333  650.0              700.0              750.0              

scala> o
res53: breeze.linalg.DenseMatrix[Double] = 
414.0              59.33333333333333  482.0              516.0              550.0              
59.33333333333333  9.555555555555555  71.33333333333333  77.33333333333333  83.33333333333333  
482.0              71.33333333333333  566.0              608.0              650.0              
516.0              77.33333333333333  608.0              654.0              700.0              
550.0              83.33333333333333  650.0              700.0              750.0              

scala> val p = n*m
p: breeze.linalg.DenseMatrix[Double] = 
52.0                117.33333333333333  182.66666666666666  248.0              
117.33333333333333  282.77777777777777  448.22222222222223  613.6666666666667  
182.66666666666666  448.22222222222223  713.7777777777778   979.3333333333334  
248.0               613.6666666666667   979.3333333333334   1345.0             

scala> p
res54: breeze.linalg.DenseMatrix[Double] = 
52.0                117.33333333333333  182.66666666666666  248.0              
117.33333333333333  282.77777777777777  448.22222222222223  613.6666666666667  
182.66666666666666  448.22222222222223  713.7777777777778   979.3333333333334  
248.0               613.6666666666667   979.3333333333334   1345.0             

So, messing around with vectors and matrices is more-or-less as convenient as in well-known dynamic and math languages. To conclude this section, let us see how to simulate some data from a regression model and then solve the least squares problem to obtain the estimated regression coefficients. We will simulate 1,000 observations from a model with 5 covariates.

scala> val X = new DenseMatrix(1000,5,gau.sample(5000).toArray)
X: breeze.linalg.DenseMatrix[Double] = 
-0.40186606934180685  0.9847148198711287    ... (5 total)
-0.4760404521336951   -0.833737041320742    ...
-0.3315199616926892   -0.19460446824586297  ...
-0.14764615494496836  -0.17947658245206904  ...
-0.8357372755800905   -2.456222113596015    ...
-0.44458309216683184  1.848007773944826     ...
0.060314034896221065  0.5254462055311016    ...
0.8637867740789016    -0.9712570453363925   ...
0.11620167261655819   -1.2231380938032232   ...
-0.3335514290842617   -0.7487303696662753   ...
-0.5598937433421866   0.11083382409013512   ...
-1.7213395389510568   1.1717491221846357    ...
-1.078873342208984    0.9386859686451607    ...
-0.7793854546738327   -0.9829373863442161   ...
-1.054275201631216    0.10100826507456745   ...
-0.6947188686537832   1.215...
scala> val b0 = linspace(1.0,2.0,5)
b0: breeze.linalg.DenseVector[Double] = DenseVector(1.0, 1.25, 1.5, 1.75, 2.0)

scala> val y0 = X * b0
y0: breeze.linalg.DenseVector[Double] = DenseVector(0.08200546839589107, -0.5992571365601228, -5.646398002309553, -7.346136663325798, -8.486423788193362, 1.451119214541837, -0.25792385841948406, 2.324936340609002, -1.2285599639827862, -4.030261316643863, -4.1732627416377674, -0.5077151099958077, -0.2087263741903591, 0.46678616461409383, 2.0244342278575975, 1.775756468177401, -4.799821190728213, -1.8518388060564481, 1.5892306875621767, -1.6528539564387008, 1.4064864330994125, -0.8734630221484178, -7.75470002781836, -0.2893619536998493, -5.972958583649336, -4.952666733286302, 0.5431255990489059, -2.477076684976403, -0.6473617571867107, -0.509338416957489, -1.5415350935719594, -0.47068802465681125, 2.546118380362026, -7.940401988804477, -1.037049442788122, -1.564016663370888, -3.3147087994...
scala> val y = y0 + DenseVector(gau.sample(1000).toArray)
y: breeze.linalg.DenseVector[Double] = DenseVector(-0.572127338358624, -0.16481167194161406, -4.213873268823003, -10.142015065601388, -7.893898543052863, 1.7881055848475076, -0.26987820512025357, 3.3289433195054148, -2.514141419925489, -4.643625974157769, -3.8061000214061886, 0.6462624993109218, 0.23603338389134149, 1.0211137806779267, 2.0061727641393317, 0.022624943149799348, -5.429601401989341, -1.836181225242386, 1.0265599173053048, -0.1673732536615371, 0.8418249443853956, -1.1547110533101967, -8.392100167478764, -1.1586377992526877, -6.400362975646245, -5.487018086963841, 0.3038055584347069, -1.2247410435868684, -0.06476921390724344, -1.5039074374120407, -1.0189111630970076, 1.307339668865724, 2.048320821568789, -8.769328824477714, -0.9104251029228555, -1.3533910178496698, -2.178788...
scala> val b = X \ y  // defaults to a QR-solve of the least squares problem
b: breeze.linalg.DenseVector[Double] = DenseVector(0.9952708232116663, 1.2344546192238952, 1.5543512339052412, 1.744091673457169, 1.9874158953720507)

So all of the most important building blocks for statistical computing are included in the Breeze library.

At this point it is really worth reminding yourself that Scala is actually a statically typed language, despite the fact that in this session we have not explicitly declared the type of anything at all! This is because Scala has type inference, which makes type declarations optional when it is straightforward for the compiler to figure out what the types must be. For example, for our very first expression, val a = 5, because the RHS is an Int, it is clear that the LHS must also be an Int, and so the compiler infers that the type of a must be an Int, and treats the code as if the type had been declared as val a: Int = 5. This type inference makes Scala feel very much like a dynamic language in general use. Typically, we carefully specify the types of function arguments (and often the return type of the function, too), but then for the main body of each function, just let the compiler figure out all of the types and write code as if the language were dynamic. To me, this seems like the best of all worlds: the convenience of dynamic languages with the safety of static typing.
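As a rough illustration of the point about inference, here is a minimal sketch using only the standard library. The names count, xs and total are made up for this example and do not appear in the session above.

```scala
// Inferred types: the compiler works these out from the right-hand sides.
val count = 5                  // inferred as Int
val xs = List(1.0, 2.5, 4.0)   // inferred as List[Double]
val total = xs.sum             // inferred as Double

// The same declarations with the types written out explicitly.
val count2: Int = 5
val xs2: List[Double] = List(1.0, 2.5, 4.0)
val total2: Double = xs2.sum

// A mismatch is rejected at compile time rather than failing at run time:
// val bad: Int = "five"   // does not compile (type mismatch)
```

The two groups of declarations should be treated identically by the compiler, so the annotations are mostly a matter of documentation.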

Declaring the types of function arguments is not usually a big deal, as the following simple example demonstrates.

scala> def mean(arr: Array[Int]): Double = {
     |   arr.sum.toDouble/arr.length
     | }
mean: (arr: Array[Int])Double

scala> mean(Array(3,1,4,5))
res55: Double = 3.25
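The function above only accepts an Array[Int]. One possible way to relax that, sketched here under the assumption that we are happy to rely on the standard library's Numeric type class, is to make the element type a parameter. The name meanOf is hypothetical, chosen to avoid clashing with the mean defined above.

```scala
// Hypothetical generalisation of mean: works for any element type T
// that has a Numeric instance (Int, Long, Double, ...).
def meanOf[T](xs: Seq[T])(implicit num: Numeric[T]): Double =
  num.toDouble(xs.sum) / xs.length

// Usage with two different element types.
val m1 = meanOf(List(3, 1, 4, 5))   // 3.25
val m2 = meanOf(Vector(1.5, 2.5))   // 2.0
```

The argument and return types are still declared explicitly, but everything inside the body is inferred.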

A complete Scala program

For completeness, I will finish this post with a very simple but complete Scala/Breeze program. In a previous post I discussed a simple Gibbs sampler in Scala, but in that post I used the Java COLT library for random number generation. Below is a version using Breeze instead.

object BreezeGibbs {

  import breeze.stats.distributions._
  import scala.math.sqrt

  class State(val x: Double, val y: Double)

  def nextIter(s: State): State = {
    val newX = Gamma(3.0, 1.0 / ((s.y) * (s.y) + 4.0)).draw()
    new State(newX, Gaussian(1.0 / (newX + 1), 1.0 / sqrt(2 * newX + 2)).draw())
  }

  def nextThinnedIter(s: State, left: Int): State = {
    if (left == 0) s
    else nextThinnedIter(nextIter(s), left - 1)
  }

  def genIters(s: State, current: Int, stop: Int, thin: Int): State = {
    if (current <= stop) {
      println(s"$current ${s.x} ${s.y}")
      genIters(nextThinnedIter(s, thin), current + 1, stop, thin)
    } else s
  }

  def main(args: Array[String]): Unit = {
    println("Iter x y")
    genIters(new State(0.0, 0.0), 1, 50000, 1000)
  }

}
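The recursion in nextThinnedIter and genIters can also be read as "apply the update thin times, then emit a state, and repeat". A minimal sketch of that reading, using only the standard library's Iterator.iterate, is given below. The helper thinnedStates is hypothetical and not part of the program above, and a deterministic step (adding one to an Int) stands in for nextIter so that the behaviour is easy to check by hand.

```scala
// Hypothetical helper: an iterator of states in which consecutive
// elements are separated by `thin` applications of `step`.
def thinnedStates[S](init: S, thin: Int)(step: S => S): Iterator[S] =
  Iterator.iterate(init) { s =>
    // advance the chain `thin` times from s and return the resulting state
    Iterator.iterate(s)(step).drop(thin).next()
  }

// Deterministic stand-in for a sampler update, thinned by 3.
val firstFour = thinnedStates(0, 3)(_ + 1).take(4).toList
// firstFour is List(0, 3, 6, 9)
```

With the Gibbs sampler above, one would presumably pass nextIter as the step and a State as the initial value, then take as many thinned states as required.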

Summary

In this post I’ve tried to give a quick taste of the Scala language and the Breeze library for those used to dynamic languages for statistical computing. Hopefully I’ve illustrated that the basics don’t look too different, so there is no reason to fear Scala. It is perfectly possible to start using Scala as a better and faster Python or R. Once you’ve mastered the basics, you can then start exploring the full power of the language. There’s loads of introductory Scala material to be found on-line. It probably makes sense to start with the links I’ve highlighted above. After that, just start searching – there’s an interesting set of tutorials I noticed just the other day. A very time-efficient way to learn Scala quickly is to do the FP with Scala course on Coursera, but whether this makes sense will depend on when it is next running. For those who prefer real books, the book Programming in Scala is the standard reference, and I’ve also found Functional programming in Scala to be useful (free text of the first edition of the former and a draft of the latter can be found on-line).

REPL Script

Below is a copy of the complete REPL script, for reference.

// start with non-Breeze stuff

val a = 5
a
a = 6 // error: reassignment to val (a is immutable)
a

var b = 7
b
b = 8
b

val c = List(3,4,5,6)
c(1)
c.sum
c.length
c.product
c.foldLeft(0)((x,y) => x+y)
c.foldLeft(0)(_+_)
c.foldLeft(1)(_*_)

val d = Vector(2,3,4,5,6,7,8,9)
d
d.slice(3,6)
val e = d.updated(3,0)
d
e

val f=(1 to 10).toList
f
f.map(x => x*x)
f map {x => x*x}
f filter {_ > 4}

// introduce breeze through random distributions
// https://github.com/scalanlp/breeze/wiki/Quickstart

import breeze.stats.distributions._
val poi = Poisson(3.0)
poi.draw
poi.draw
val x = poi.sample(10)
x
x.sum
x.length
x.sum.toDouble/x.length
poi.probabilityOf(2)
x map {x => poi.probabilityOf(x)}
x map {poi.probabilityOf(_)}

val gau=Gaussian(0.0,1.0)
gau.draw
gau.draw
val y=gau.sample(20)
y
y.sum/y.length
y map {gau.logPdf(_)}

Gamma(2.0,3.0).sample(5)

import breeze.stats.DescriptiveStats._
mean(y)
variance(y)
meanAndVariance(y)


// move on to linear algebra
// https://github.com/scalanlp/breeze/wiki/Breeze-Linear-Algebra

import breeze.linalg._
val v=DenseVector(y.toArray)
v(1) = 0
v
v(1 to 3) := 1.0
v
v(1 to 3) := DenseVector(1.0,1.5,2.0)
v
v :> 0.0
(v :> 0.0).toArray

val m = new DenseMatrix(5,4,linspace(1.0,20.0,20).toArray)
m
m.rows
m.cols
m(::,1)
m(1,::)
m(1,::) := linspace(1.0,2.0,4)
m

val n = m.t
n
val o = m*n
o
val p = n*m
p

// regression and QR solution

val X = new DenseMatrix(1000,5,gau.sample(5000).toArray)
val b0 = linspace(1.0,2.0,5)
val y0 = X * b0
val y = y0 + DenseVector(gau.sample(1000).toArray)
val b = X \ y  // defaults to a QR-solve of the least squares problem

// a simple function example

def mean(arr: Array[Int]): Double = {
  arr.sum.toDouble/arr.length
}

mean(Array(3,1,4,5))