scala-glm: Regression modelling in Scala

Introduction

As discussed in the previous post, I’ve recently constructed and delivered a short course on statistical computing with Scala. Much of the course is concerned with writing statistical algorithms in Scala, typically making use of the scientific and numerical computing library, Breeze. Breeze has all of the essential tools necessary for building statistical algorithms, but doesn’t contain any higher level modelling functionality. As part of the course, I walked through how to build a small library for regression modelling on top of Breeze, including all of the usual regression diagnostics (such as standard errors, t-statistics, p-values, F-statistics, etc.). While preparing the course materials it occurred to me that it would be useful to package and document this code properly for general use. In advance of the course I packaged the code up into a bare-bones library, but since then I’ve fleshed it out, tidied it up and documented it properly, so it’s now ready for people to use.

The library covers PCA, linear regression modelling and simple one-parameter GLMs (including logistic and Poisson regression). The underlying algorithms are fairly efficient and numerically stable (e.g. linear regression uses the QR decomposition of the model matrix, and the GLM fitting uses QR within each IRLS step), though they are optimised more for clarity than speed. The library also includes a few utility functions and procedures, including a pairs plot (scatter-plot matrix).
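To give an idea of what this means in practice, here is a minimal Breeze sketch (not part of the library, and using simulated data purely for illustration) of the QR approach to least squares that Lm is built around: factorising the model matrix as X = QR and then solving R beta = Q'y avoids ever forming the potentially ill-conditioned cross-product matrix X'X.

import breeze.linalg._

// simulated data, purely for illustration
val n = 500
val X = DenseMatrix.horzcat(
  DenseMatrix.ones[Double](n, 1), // intercept column
  DenseMatrix.rand(n, 3))         // three random covariates
val beta = DenseVector(1.0, 2.0, -1.0, 0.5)
val y = X * beta + DenseVector.rand(n) // add some noise

// least squares via the thin QR decomposition: X = QR, so R beta = Q'y
val qrX = qr.reduced(X)
val q = qrX.q
val r = qrX.r
val betaHat = r \ (q.t * y) // solve the triangular system for the estimates

In scala-glm, this computation (together with the intercept handling and all of the diagnostics) is wrapped up in the Lm class used in the example below.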

A linear regression example

Plenty of documentation is available from the scala-glm GitHub repo, which I won't repeat here. But to give a rough idea of how things work, I'll run through an interactive session for the linear regression example.

First, download a dataset from the UCI ML Repository to disk for subsequent analysis (caching the file on disk is good practice, as it avoids unnecessary load on the UCI server, and allows running the code off-line):

import scalaglm._
import breeze.linalg._

val url = "http://archive.ics.uci.edu/ml/machine-learning-databases/00291/airfoil_self_noise.dat"
val fileName = "self-noise.csv"

// download the file to disk if it hasn't been already
val file = new java.io.File(fileName)
if (!file.exists) {
  val s = new java.io.PrintWriter(file)
  val data = scala.io.Source.fromURL(url).getLines
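  // convert each tab-separated line to CSV, dropping any empty fields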
  data.foreach(l => s.write(l.trim.
    split('\t').filter(_ != "").
    mkString("", ",", "\n")))
  s.close
}

Once we have a CSV file on disk, we can load it up and look at it.

val mat = csvread(new java.io.File(fileName))
// mat: breeze.linalg.DenseMatrix[Double] =
// 800.0    0.0  0.3048  71.3  0.00266337  126.201
// 1000.0   0.0  0.3048  71.3  0.00266337  125.201
// 1250.0   0.0  0.3048  71.3  0.00266337  125.951
// ...
println("Dim: " + mat.rows + " " + mat.cols)
// Dim: 1503 6
val figp = Utils.pairs(mat, List("Freq", "Angle", "Chord", "Velo", "Thick", "Sound"))
// figp: breeze.plot.Figure = breeze.plot.Figure@37718125

We can then regress the response in the final column on the other variables.

val y = mat(::, 5) // response is the final column
// y: DenseVector[Double] = DenseVector(126.201, 125.201, ...
val X = mat(::, 0 to 4)
// X: breeze.linalg.DenseMatrix[Double] =
// 800.0    0.0  0.3048  71.3  0.00266337
// 1000.0   0.0  0.3048  71.3  0.00266337
// 1250.0   0.0  0.3048  71.3  0.00266337
// ...
val mod = Lm(y, X, List("Freq", "Angle", "Chord", "Velo", "Thick"))
// mod: scalaglm.Lm =
// Lm(DenseVector(126.201, 125.201, ...
mod.summary
// Estimate	 S.E.	 t-stat	p-value		Variable
// ---------------------------------------------------------
// 132.8338	 0.545	243.866	0.0000 *	(Intercept)
//  -0.0013	 0.000	-30.452	0.0000 *	Freq
//  -0.4219	 0.039	-10.847	0.0000 *	Angle
// -35.6880	 1.630	-21.889	0.0000 *	Chord
//   0.0999	 0.008	12.279	0.0000 *	Velo
// -147.3005	15.015	-9.810	0.0000 *	Thick
// Residual standard error:   4.8089 on 1497 degrees of freedom
// Multiple R-squared: 0.5157, Adjusted R-squared: 0.5141
// F-statistic: 318.8243 on 5 and 1497 DF, p-value: 0.00000
val fig = mod.plots
// fig: breeze.plot.Figure = breeze.plot.Figure@60d7ebb0

There is a .predict method for generating point predictions (and standard errors) given a new model matrix, and fitting GLMs is very similar – these things are covered in the quickstart guide for the library.
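For example, prediction and a logistic regression fit look roughly like this (a sketch based on my reading of the quickstart guide, so treat the exact method and field names, and the artificial binary response, as assumptions to check against the documentation):

// point predictions (and standard errors) at a new model matrix -
// here we simply reuse the first ten rows of X as the "new" data
val Xnew = X(0 to 9, ::)
val pred = mod.predict(Xnew)
// pred.fitted should contain the point predictions and pred.se their standard errors

// fitting a GLM is very similar - e.g. a logistic regression for a
// (completely artificial) binary response obtained by thresholding y
val yb = y.map(v => if (v > 125.0) 1.0 else 0.0)
val glmMod = Glm(yb, X, List("Freq", "Angle", "Chord", "Velo", "Thick"), LogisticGlm)
glmMod.summary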

Summary

scala-glm is a small Scala library built on top of the Breeze numerical library which enables simple and convenient regression modelling in Scala. It is reasonably well documented and usable in its current form, but I intend to gradually add additional features according to demand as time permits.

Statistics for Big Data

Doctoral programme in cloud computing for big data

I’ve spent much of this year working to establish our new EPSRC Centre for Doctoral Training in Cloud Computing for Big Data, which partly explains the lack of posts on this blog in recent months. The CDT is now established, with 11 students in the first cohort, and we have begun recruiting for the second cohort, to start in September 2015. We admit roughly equal numbers of students from Computing Science and from Mathematics/Statistics backgrounds, and provide an intensive programme of inter-disciplinary training in the first (of four) years.

After initial induction and cohort team-building events, the programme begins with 8 weeks of intensive bespoke training developed especially for the CDT students. For the first two weeks the cohort is split into two streams for an initial crash course. Students from an M/S background receive basic training in CS and programming, with an emphasis on object-oriented programming in Java. Students from a CS background get basic training in M/S, emphasising elementary probability and statistics and basic linear algebra. From the start of the third week the entire cohort is trained together. This whole-cohort training begins with two 6-week courses running in parallel. One is Programming for big data, concentrating on R programming, Java, databases, and software development tools and techniques. The other is Statistics for big data, a course that I developed and taught, and the main topic of this post. After the initial 8 weeks of specialist training, the students next take courses on Cloud Computing and Machine learning, followed by Big data analytics and Time series data, and finally courses on Research skills and Professional skills, which run alongside a Group project.

Statistics for big data

I have found it an interesting challenge to try to put together a 15-credit (150 hours of student effort) course that starts from a basic level of statistical knowledge and covers most of the important concepts and methods necessary for modelling and analysis of large and complex data sets. Although the course was not about Big Data per se, it emphasised scalability in general, and exploitation of linearity in particular. Given the mixed levels of mathematical sophistication of the cohort, I felt it important to start off with a practical, non-technical introduction. For this I covered the first 6 chapters of Introduction to Statistical Learning with R. This provided the students with a basic grounding in statistical modelling ideas, and practical skills in working with data in R. The book is excellent as a non-technical introduction, but is missing the underpinning mathematical and computational details necessary for understanding how to implement such methods in practice. I’ve not found Elements of Statistical Learning to be very suitable as a student text, so I revised, expanded and updated my old multivariate notes to produce a new set of notes on Multivariate Data Analysis using R (I discussed an early version of these notes in a previous post). These notes emphasise the role of numerical linear algebra in solving problems of linear statistical inference in an efficient and numerically stable manner. In particular, they illustrate the use of different matrix factorisations, such as the Cholesky, QR and SVD, as well as spectral decompositions, for constructing efficient solutions. Although some of the students with a weaker mathematical background struggled with some of the more technical topics, the associated practical sessions illustrated how the techniques are used in practice.

After filling in a few additional topics of frequentist linear statistical inference (including ANOVA, contrasts, missing data, and experimental design), we moved on to computational Bayesian inference. For the introductory material, I used a variety of sources, including Bayesian reasoning and machine learning, and some introductory material from my own textbook, Stochastic modelling for systems biology. The emphasis for this part of the course was the use of flexible Bayesian hierarchical modelling using MCMC. The primary software tool was JAGS, used via the rjags package from R. For more advanced material on Bayesian computation, I made substantial use of various posts on this blog, as well as some other material we use for teaching Bayesian inference as part of our undergraduate programmes. As for the first part of the course, the emphasis was on practical modelling and data analysis rather than theory. A few of the computer practicals from the Bayesian part of the course are on-line. I may re-write one or two of these practicals as blog posts in due course. Although the emphasis was on flexible modelling, we did touch on efficiency and exploitation of linearity, and I included an extra chapter on linear Bayesian inference in my multivariate notes in this context. I rounded off this part of the course with a lecture on parallelisation of Monte Carlo algorithms, along the lines of my BIRS lecture on that topic.

The course finished with a group data analysis project, concerned with linear modelling and variable selection for a fairly large heterogeneous data set containing missing data, using both frequentist and Bayesian approaches. The course has just finished, and I’m reasonably happy with how it has gone, but I’ll reflect on it for a couple of weeks and get some feedback before deciding on some revisions to make before delivering it again for the new cohort next year.

Gibbs sampling a Gaussian Markov random field (GMRF) using Java

Introduction

As I’ve explained previously, I’m gradually coming around to the idea of using Java for the development of MCMC codes, and I’m starting to build up a collection of simple examples for getting started. One of the advantages of Java is that it includes a standard cross-platform GUI library. This might not seem like the most important requirement for MCMC, but can actually be very handy in several contexts, particularly for monitoring convergence. One obvious context is that of image analysis, where it can be useful to monitor image reconstructions as the sampler is running. In this post I’ll show three small, simple Java classes which together provide an application for running a Gibbs sampler on a (non-stationary, unconditioned) Gaussian Markov random field.

The model is essentially that the distribution of each pixel is defined intrinsically, dependent only on its four nearest neighbours on a rectangular lattice, and here the distribution will be Gaussian with mean equal to the sample mean of the four neighbouring pixels and a fixed (unit) variance. On its own this isn’t especially useful, but it is a key component of many image analysis applications.
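In symbols, the full conditional exploited by the Gibbs sampler for an interior pixel is

x_{i,j} \mid x_{-(i,j)} \;\sim\; N\!\left(\tfrac{1}{4}\left(x_{i-1,j}+x_{i+1,j}+x_{i,j-1}+x_{i,j+1}\right),\, 1\right),

with pixels on the edges and corners of the lattice averaging over their three or two available neighbours, respectively; these are exactly the special cases handled explicitly in the updateOnce method below.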

A simple Java implementation

We will start with the class MrfApp containing the main method for the application:

MrfApp.java

import java.io.*;
class MrfApp {
    public static void main(String[] arg)
	throws IOException
    {
	Mrf mrf;
	System.out.println("started program");
	mrf=new Mrf(800,600);
	System.out.println("created mrf object");
	mrf.update(1000);
	System.out.println("done updates");
	mrf.saveImage("mrf.png");
	System.out.println("finished program");
	mrf.frame.dispose();
	System.exit(0);
    }
}

Hopefully this code is largely self-explanatory, but relies on a class called Mrf which contains all of the logic associated with the GMRF.

Mrf.java

import java.io.*;
import java.util.*;
import java.awt.image.*;
import javax.swing.*;
import javax.imageio.ImageIO;


class Mrf 
{
    int n,m;
    double[][] cells;
    Random rng;
    BufferedImage bi;
    WritableRaster wr;
    JFrame frame;
    ImagePanel ip;
    
    Mrf(int n_arg,int m_arg)
    {
	n=n_arg;
	m=m_arg;
	cells=new double[n][m];
	rng=new Random();
	bi=new BufferedImage(n,m,BufferedImage.TYPE_BYTE_GRAY);
	wr=bi.getRaster();
	frame=new JFrame("MRF");
	frame.setSize(n,m);
	frame.add(new ImagePanel(bi));
	frame.setVisible(true);
    }
    
    public void saveImage(String filename)
	throws IOException
    {
	ImageIO.write(bi,"PNG",new File(filename));
    }
    
    public void updateImage()
    {
	double mx=-1e+100;
	double mn=1e+100;
	for (int i=0;i<n;i++) {
	    for (int j=0;j<m;j++) {
		if (cells[i][j]>mx) { mx=cells[i][j]; }
		if (cells[i][j]<mn) { mn=cells[i][j]; }
	    }
	}
	for (int i=0;i<n;i++) {
	    for (int j=0;j<m;j++) {
		int level=(int) (255*(cells[i][j]-mn)/(mx-mn));
		wr.setSample(i,j,0,level);
	    }
	}
	frame.repaint();
    }
    
    public void update(int num)
    {
	for (int i=0;i<num;i++) {
	    updateOnce();
	}
    }
    
    private void updateOnce()
    {
	double mean;
	for (int i=0;i<n;i++) {
	    for (int j=0;j<m;j++) {
		if (i==0) {
		    if (j==0) {
			mean=0.5*(cells[0][1]+cells[1][0]);
		    } 
		    else if (j==m-1) {
			mean=0.5*(cells[0][j-1]+cells[1][j]);
		    } 
		    else {
			mean=(cells[0][j-1]+cells[0][j+1]+cells[1][j])/3.0;
		    }
		}
		else if (i==n-1) {
		    if (j==0) {
			mean=0.5*(cells[i][1]+cells[i-1][0]);
		    }
		    else if (j==m-1) {
			mean=0.5*(cells[i][j-1]+cells[i-1][j]);
		    }
		    else {
			mean=(cells[i][j-1]+cells[i][j+1]+cells[i-1][j])/3.0;
		    }
		}
		else if (j==0) {
		    mean=(cells[i-1][0]+cells[i+1][0]+cells[i][1])/3.0;
		}
		else if (j==m-1) {
		    mean=(cells[i-1][j]+cells[i+1][j]+cells[i][j-1])/3.0;
		}
		else {
		    mean=0.25*(cells[i][j-1]+cells[i][j+1]+cells[i+1][j]
			       +cells[i-1][j]);
		}
		cells[i][j]=mean+rng.nextGaussian();
	    }
	}
	updateImage();
    }
    
}

This class contains a few simple methods for creating and updating the GMRF, and also for maintaining and updating a graphical view of the GMRF as the sampler is running. The Gibbs sampler update itself is encoded in the final method, updateOnce, and most of the code is there to deal with edge and corner cases (in the literal rather than metaphorical sense!). This is called repeatedly by the method update for the required number of iterations. At the end of each iteration, updateOnce triggers updateImage, which refreshes the image associated with the GMRF. The GMRF itself is stored in a 2-dimensional array of doubles, but an image pixel typically consists of a greyscale value represented by an unsigned byte – that is, an integer from 0 to 255. So updateImage scans through the GMRF to find the maximum and minimum values and then maps the GMRF values onto the 0 to 255 scale. The image itself is set up by the constructor method, Mrf. This class relies on an additional class called ImagePanel, which is a simple GUI panel for displaying images:

ImagePanel.java

import java.awt.*;
import java.awt.image.*;
import javax.swing.*;

class ImagePanel extends JPanel {

	protected BufferedImage image;

	public ImagePanel(BufferedImage image) {
		this.image=image;
		Dimension dim=new Dimension(image.getWidth(),image.getHeight());
		setPreferredSize(dim);
		setMinimumSize(dim);
		revalidate();
		repaint();
	}

	public void paintComponent(Graphics g) {
		g.drawImage(image,0,0,this);
	}

}

This completes the application, which can be compiled and run from the command line with

javac *.java
java MrfApp

This should compile the code and run the application, which will show a GMRF updating for 1000 iterations. When the 1000 iterations are complete, the application writes the final image to a file and then quits.

Using Parallel COLT

The above classes are very convenient, as they should work with any standard Java installation. However, in more complex scenarios, it is likely that a math library such as Parallel COLT will be required. In this case it will make sense to make use of features in the COLT library, such as random number generators and 2d matrix objects. We can adapt the above application by replacing the MrfApp and Mrf classes with the following versions (the ImagePanel class remains unchanged):

MrfApp.java

import java.io.*;
import cern.jet.random.tdouble.engine.*;

class MrfApp {

    public static void main(String[] arg)
	throws IOException
    {
	Mrf mrf;
	int seed=1234;
	System.out.println("started program");
        DoubleRandomEngine rngEngine=new DoubleMersenneTwister(seed);
	mrf=new Mrf(800,600,rngEngine);
	System.out.println("created mrf object");
	mrf.update(1000);
	System.out.println("done updates");
	mrf.saveImage("mrf.png");
	System.out.println("finished program");
	mrf.frame.dispose();
	System.exit(0);
    }

}

Mrf.java

import java.io.*;
import java.util.*;
import java.awt.image.*;
import javax.swing.*;
import javax.imageio.ImageIO;
import cern.jet.random.tdouble.*;
import cern.jet.random.tdouble.engine.*;
import cern.colt.matrix.tdouble.impl.*;

class Mrf 
{
    int n,m;
    DenseDoubleMatrix2D cells;
    DoubleRandomEngine rng;
    Normal rngN;
    BufferedImage bi;
    WritableRaster wr;
    JFrame frame;
    ImagePanel ip;
    
    Mrf(int n_arg,int m_arg,DoubleRandomEngine rng)
    {
	n=n_arg;
	m=m_arg;
	cells=new DenseDoubleMatrix2D(n,m);
	this.rng=rng;
	rngN=new Normal(0.0,1.0,rng);
	bi=new BufferedImage(n,m,BufferedImage.TYPE_BYTE_GRAY);
	wr=bi.getRaster();
	frame=new JFrame("MRF");
	frame.setSize(n,m);
	frame.add(new ImagePanel(bi));
	frame.setVisible(true);
    }
    
    public void saveImage(String filename)
	throws IOException
    {
	ImageIO.write(bi,"PNG",new File(filename));
    }
    
    public void updateImage()
    {
	double mx=-1e+100;
	double mn=1e+100;
	for (int i=0;i<n;i++) {
	    for (int j=0;j<m;j++) {
		if (cells.getQuick(i,j)>mx) { mx=cells.getQuick(i,j); }
		if (cells.getQuick(i,j)<mn) { mn=cells.getQuick(i,j); }
	    }
	}
	for (int i=0;i<n;i++) {
	    for (int j=0;j<m;j++) {
		int level=(int) (255*(cells.getQuick(i,j)-mn)/(mx-mn));
		wr.setSample(i,j,0,level);
	    }
	}
	frame.repaint();
    }
    
    public void update(int num)
    {
	for (int i=0;i<num;i++) {
	    updateOnce();
	}
    }
    
    private void updateOnce()
    {
	double mean;
	for (int i=0;i<n;i++) {
	    for (int j=0;j<m;j++) {
		if (i==0) {
		    if (j==0) {
			mean=0.5*(cells.getQuick(0,1)+cells.getQuick(1,0));
		    } 
		    else if (j==m-1) {
			mean=0.5*(cells.getQuick(0,j-1)+cells.getQuick(1,j));
		    } 
		    else {
			mean=(cells.getQuick(0,j-1)+cells.getQuick(0,j+1)+cells.getQuick(1,j))/3.0;
		    }
		}
		else if (i==n-1) {
		    if (j==0) {
			mean=0.5*(cells.getQuick(i,1)+cells.getQuick(i-1,0));
		    }
		    else if (j==m-1) {
			mean=0.5*(cells.getQuick(i,j-1)+cells.getQuick(i-1,j));
		    }
		    else {
			mean=(cells.getQuick(i,j-1)+cells.getQuick(i,j+1)+cells.getQuick(i-1,j))/3.0;
		    }
		}
		else if (j==0) {
		    mean=(cells.getQuick(i-1,0)+cells.getQuick(i+1,0)+cells.getQuick(i,1))/3.0;
		}
		else if (j==m-1) {
		    mean=(cells.getQuick(i-1,j)+cells.getQuick(i+1,j)+cells.getQuick(i,j-1))/3.0;
		}
		else {
		    mean=0.25*(cells.getQuick(i,j-1)+cells.getQuick(i,j+1)+cells.getQuick(i+1,j)
			       +cells.getQuick(i-1,j));
		}
		cells.setQuick(i,j,mean+rngN.nextDouble());
	    }
	}
	updateImage();
    }
    
}

Again, the code should be reasonably self-explanatory, and will compile and run in the same way, provided that Parallel COLT is installed and in your classpath. This version runs approximately twice as fast as the previous version on all of the machines I’ve tried it on.

Reference

I have found the following book very useful for understanding how to work with images in Java:

Hunt, K.A. (2010) The Art of Image Processing with Java, A K Peters/CRC Press.

Multivariate data analysis (using R)

I’ve been very quiet on-line in the last few months, due mainly to the fact that I’ve been writing a new undergraduate course on multivariate data analysis. Although there are many books and on-line notes on the general topic of multivariate statistics, I wanted to do something a little bit different from any text I have yet discovered. First, I wanted to have a strong emphasis on using techniques in practice on example data sets of reasonable size. For this, I found Hastie et al. (2009) to be very useful, as it covered some interesting example data sets which have been bundled in the CRAN R package, ElemStatLearn. I used several of the data sets from this package as running examples throughout the course. In fact my initial plan was to use Hastie et al. as the main course text, but it turned out that this text was in some places overly technical and in many places far too terse to work well as an undergraduate text. I would still recommend the book to researchers who want a good overview of the interface between statistics and machine learning, but with hindsight I’m not convinced it is ideal for typical statistics undergraduate students.

I also wanted to have a strong emphasis on numerical linear algebra as the basis for multivariate statistical computation. Again, this is a bit different from “old school” multivariate statistics (which reminds me, John Marden has produced a great text, available freely on-line, on old school multivariate analysis, which isn’t quite as “old school” as the title might suggest). I wanted to spend some time talking about linear systems and matrix factorisations: explaining, for example, how the LU decomposition, the Cholesky factorisation and the QR factorisation are related, and why the latter two are both fundamental to multivariate data analysis; how the singular value decomposition (SVD) is related to the spectral decomposition; and why it is generally better to construct principal components from the SVD of the centred data matrix than from the eigen-decomposition of the sample variance matrix. These sorts of topics are not often covered in undergraduate statistics courses, but they are crucial to understanding how to analyse large multivariate data sets in a numerically stable way.
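To make the last point concrete, here is a small sketch in Breeze/Scala (rather than R, for consistency with the scala-glm examples earlier, and on toy simulated data) computing principal components both ways; the loadings and variances agree, but the SVD route never forms the cross-product matrix, whose condition number is the square of that of the centred data matrix.

import breeze.linalg._
import breeze.stats.mean

// toy simulated data: 100 observations on 5 variables (illustration only)
val x = DenseMatrix.rand(100, 5)
val xc = x(*, ::) - mean(x(::, *)).t   // centre each column

// principal components via the SVD of the centred data matrix
val svd.SVD(u, s, vt) = svd(xc)
val loadings = vt.t                    // PC loadings in the columns
val pcVar = (s *:* s) / (x.rows - 1.0) // variances of the components

// the same quantities via the eigen-decomposition of the sample variance matrix,
// which requires explicitly forming the cross-product matrix xc' * xc
val v = (xc.t * xc) / (x.rows - 1.0)
val es = eigSym(v) // es.eigenvalues (ascending) and es.eigenvectors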

I also wanted to downplay distribution theory as much as possible, as multivariate distribution theory is quite difficult, and not necessary for understanding most of the essential concepts in multivariate data analysis. Also, it is not obviously very useful. Essentially all introductory courses are based around the multivariate normal distribution, but I have yet to see a real non-trivial multivariate data set for which an assumption of multivariate normality is appropriate. Consequently I delayed the introduction of the multivariate normal until well into the course, and didn’t bother with the Wishart distribution, or testing for multivariate normality. Like much frequentist testing, it is really just a matter of seeing if you have yet collected a large enough sample to reject the null hypothesis – I just don’t see the point (null)!

Finally, I wanted to use R to illustrate all of the methods in practice as they were introduced. We use R throughout our undergraduate statistics programme, and I think it is a good language for learning about statistical methods, algorithms and concepts. In most cases I begin by showing how to carry out analyses using “elementary” operations (such as matrix manipulations), and then go on to show how to accomplish the same task more simply using higher-level R functions and packages. Again, I think it really helps understanding to first see the mathematical description directly translated into computer code before jumping to high-level data analysis functions.

There are several aspects of the course that I would like to distil out into self-contained blog posts, but I have a busy summer schedule and a couple of other things I want to write about first, so it will be a while before I get around to it; in the meantime, anyone interested is welcome to download a copy of the course notes (PDF, with hyperlinks). This is the student version, containing gaps, but the gaps mainly correspond to bits of standard theory and examples designed to be worked through by hand. All of the essential theory and context, and all of the R examples, are present in this version of the notes. There are seven chapters: Introduction to multivariate data; PCA and matrix factorisations; Inference, the MVN and multivariate regression; Cluster analysis and unsupervised learning; Discrimination and classification; Graphical modelling; Variable selection and multiple testing.