A Dirichlet Process Implementation in Julia


The following code closely follows the lecture notes and examples found here, which provide excellent reading on Dirichlet processes. This post is less about DPs themselves and more about what an implementation in Julia might look like. The example uses the “Traffic” dataset from the “MASS” library, whose description reads:

An experiment was performed in Sweden in 1961-2 to assess the effect of a speed limit on the motorway accident rate. The experiment was conducted on 92 days in each year, matched so that day j in 1962 was comparable to day j in 1961. On some days the speed limit was in effect and enforced, while on other days there was no speed limit and cars tended to be driven faster. The speed limit days tended to be in contiguous blocks.

Let’s load the dataset into Julia and inspect histograms of the number of accidents each day when the speed limit is and is not enforced.

using RDatasets
using DataFrames
using Gadfly
using Distributions
using Iterators

speed=dataset("MASS","Traffic")
limityes=speed[speed[:Limit].=="yes",:Y]
nyes=length(limityes)
limitno=speed[speed[:Limit].=="no",:Y]
nno=length(limitno)
plot(speed, x="Y", color="Limit", Geom.histogram(bincount=20,density=true))
(Figure: overlaid histograms of the daily accident counts Y, on a density scale, coloured by whether the speed limit was in effect.)

A natural question to ask is: what is the posterior probability that the number of accidents is lower on a day when the speed limit is enforced than on a day when it is not? To answer this, one could build a statistical model using a Dirichlet process.

Dirichlet Process Model Specification

The data in each group are modelled as a Dirichlet process mixture,

$$Y_i \mid \theta_i \sim g(\,\cdot \mid \theta_i), \qquad \theta_i \mid G \sim G, \qquad G \sim \mathrm{DP}(a, G_0),$$

where $g$ is a parametric likelihood, $G_0$ is the base measure and $a>0$ is the concentration parameter. Integrating out $G$ yields the Pólya urn representation, so the full conditional of each $\theta_i$ is

$$\theta_i \mid \theta_{-i}, Y_i \sim c\left(\sum_{j \neq i} g(Y_i \mid \theta_j)\,\delta_{\theta_j} + a\, m(Y_i)\, h(\theta_i \mid Y_i)\right),$$

with $m(Y_i) = \int g(Y_i \mid \theta)\,\mathrm{d}G_0(\theta)$ the marginal likelihood of $Y_i$ under the base measure, $h(\theta \mid Y_i)$ the corresponding posterior, and $c$ a normalizing constant so that the weights sum to one.

A direct-scan Gibbs sampler over these full conditionals could then be implemented as follows:

function randDP(loglikelihood,logmarginallikelihood,posterior,Y,a,Θin)
    Θ=deepcopy(Θin)
    for i = 1:length(Y)
        # Log-likelihood of Y[i] under every current θ_j ...
        weights = map(x -> loglikelihood(Y[i],x),Θ)
        # ... with slot i replaced by the "new value" term log m(Y_i)
        weights[i] = logmarginallikelihood(Y[i])
        # Exponentiate after subtracting the maximum for numerical stability
        weights = exp(weights-maximum(weights))
        # Scale the new-value weight by the concentration parameter and normalize
        weights[i] = weights[i]*a
        weights = weights/sum(weights)
        # Sample a component; if index i is chosen, draw a fresh θ from the posterior h(θ|Y_i)
        select = rand(Categorical(weights))
        Θ[i] = (select == i ) ? rand(posterior(Y[i])) : Θ[select]
    end
    Θ
end

For the Swedish speed limit dataset, natural assumptions would be that the likelihood $g$ is a Poisson pmf and the base measure $G_0$ is a $\mathrm{Gamma}(b_1, b_2)$ distribution (shape $b_1$, rate $b_2$).
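With this conjugate choice, the marginal likelihood and the posterior under the base measure, the two remaining ingredients randDP needs, have closed forms (the standard Poisson–Gamma calculation, with the Gamma written in shape–rate form):

$$m(y) = \int \mathrm{Poisson}(y \mid \theta)\,\mathrm{Gamma}(\theta \mid b_1, b_2)\,\mathrm{d}\theta = \mathrm{NegBin}\!\left(b_1,\ \tfrac{b_2}{b_2+1}\right), \qquad h(\theta \mid y) = \mathrm{Gamma}(b_1 + y,\ b_2 + 1).$$

These are exactly the three distributions supplied to randDP in the code below.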

Gibbs Sampling

# Partially apply the leading arguments of f, returning a function of the remaining arguments
function partial(f,a...)
    (b...) -> f(a...,b...)
end
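As a quick aside, partial simply fixes the leading arguments of a function. A toy illustration (the name add3 is hypothetical and not part of the analysis):

add3 = partial(+, 1, 2)   # a function that adds 1 + 2 to whatever it receives
add3(4)                   # returns 7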

# Base measure Gamma(b1, b2) hyperparameters (shape b1, rate b2) and DP concentration a
b1=2
b2=0.1
a=1
PoissonGammaDP = partial(randDP,
                        (y,θ)->logpdf(Poisson(θ),y),                 # Poisson log-likelihood g
                        y->logpdf(NegativeBinomial(b1,b2/(b2+1)),y), # log marginal likelihood m
                        y->Gamma(b1+y,1/(b2+1)))                     # posterior h; Distributions' Gamma takes shape and scale, so rate b2+1 enters as 1/(b2+1)

# Lazily evaluated, infinite Markov chains of θ arrays, one chain per group
mcmcyes = iterate(partial(PoissonGammaDP,
                        limityes,a),
                    ones(length(limityes)))

mcmcno = iterate(partial(PoissonGammaDP,
                        limitno,a),
                    ones(length(limitno)))
# Drop 1000 burn-in iterates, take the next 10000 and keep every 10th
drawsyes=collect(takenth(take(drop(mcmcyes,1000),10000),10))
drawsno=collect(takenth(take(drop(mcmcno,1000),10000),10))
# Posterior predictive for each retained draw: an equally weighted mixture of the
# fitted Poisson components plus the negative binomial prior predictive term
distsyes=map( x -> MixtureModel([MixtureModel(map(y -> Poisson(y),x)),NegativeBinomial(b1,b2/(b2+1))],[nyes/(a+nyes), a/(a+nyes)]), drawsyes)
distsno=map( x -> MixtureModel([MixtureModel(map(y -> Poisson(y),x)),NegativeBinomial(b1,b2/(b2+1))],[nno/(a+nno), a/(a+nno)]), drawsno)
# Monte Carlo estimate of P(accidents with the limit < accidents without the limit)
mean(x -> x[1]<x[2],zip(map(x->rand(x),distsyes),map(x->rand(x),distsno)))
0.603

The iterate function from the Iterators.jl package forms a lazily evaluated list of MCMC iterates. In a functional style the first 1000 iterates can be dropped as burn-in. One can then “take” 10000 iterates from the chain (albeit lazily) and from those keep every 10th iterate using the takenth function. The collect function then forces the evaluation of the lazily evaluated list. Each iterate consists of a $\theta$ array, which can be used to form a posterior predictive distribution. Hence one can map a function that takes a $\theta$ array and returns a posterior predictive distribution over the MCMC draws to get an array of posterior predictive distributions, which is what distsyes and distsno hold. One can then generate random variates from each element of these arrays in turn, zip them together, and supply a predicate which evaluates to true whenever the second element in a pair is greater than the first. Computing the mean of this Boolean array then gives a Monte Carlo estimate of the probability that the number of accidents on a day when there is no speed limit is greater than the number of accidents on a day when the speed limit is enforced, namely 0.603.
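Concretely, given one retained draw $\theta_1,\dots,\theta_n$ for a group of $n$ days, the posterior predictive pmf encoded by the nested MixtureModel calls above is

$$p(y \mid \theta_{1:n}) = \frac{n}{a+n}\cdot\frac{1}{n}\sum_{i=1}^{n}\mathrm{Poisson}(y \mid \theta_i) \;+\; \frac{a}{a+n}\,m(y),$$

with $m$ the negative binomial marginal from before; the inner mixture averages the fitted Poisson components and the outer mixture adds the prior predictive term with weight $a/(a+n)$.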

Visualization

Here’s a function to plot the mean posterior predictive PMFs together with pointwise 95% credible intervals.

function mcmcdataframe(draws,gitter,name)
  # Number of observations behind these draws (each draw holds one θ per data point)
  n=length(draws[1])
  # Posterior predictive distribution for each retained draw
  dists=map( x -> MixtureModel([MixtureModel(map(y -> Poisson(y),x)),NegativeBinomial(b1,b2/(b2+1))],[n/(a+n), a/(a+n)]), draws)
  # Rows index draws, columns index grid points
  gridmatrix=hcat(map(z -> pdf(z,gitter),dists)...)'
  # Pointwise posterior mean and 95% credible band of the predictive pmf
  meanfunction=map(grid -> mean(gridmatrix[:,grid]),gitter)
  upperfunction=map(grid -> quantile(gridmatrix[:,grid],0.975),gitter)
  lowerfunction=map(grid -> quantile(gridmatrix[:,grid],0.025),gitter)
  df_DP = DataFrame(
    x=gitter,
    y=meanfunction,
    ymin=lowerfunction,
    ymax=upperfunction,
    Speed_Limit=name
  )
end

df_mcmc=vcat(mcmcdataframe(drawsno,1:60,"No"),mcmcdataframe(drawsyes,1:60,"Yes"))

plot(df_mcmc, x=:x, y=:y, ymin=:ymin, ymax=:ymax, color=:Speed_Limit, Geom.line, Geom.ribbon)
(Figure: mean posterior predictive PMFs with 95% credible ribbons over accident counts 1–60, coloured by Speed_Limit “No”/“Yes”.)
