Mathsy Coding

Hierarchical Bayesian Models

A look at hierarchical models, a kind of model that allows you to share information between different clusters of data. We'll start with a simple example and then apply it to a real-world situation.

Recently I’ve been learning about multilevel models. These are a kind of model where, by clustering the data in some way, you can ‘share’ information between different pieces of the model. That makes them an excellent way to get the most out of the often limited data that you might have available. Another advantage of these models is that they provide a way to account for additional uncertainty, in much the same way that a mixture model does. Let’s go through and see how to construct such a model (specifically, a varying-intercepts model) and cap it off with a real-world application!

Example - Heights

Let’s start off with a relatively simple example. We’re going to look at the simulated heights of people in different countries. In this situation, there is a natural hierarchical structure to the data: each height belongs to a person and each person belongs to a country. We might naturally expect that the country that a person is part of will have some influence on their height. There are a few ways that we could approach this:

  1. Ignore the countries. We could just ignore the fact that the people are in different countries and group them all together when making our model. This approach is called complete pooling.
  2. Fit separate models. Another approach would be to fit a separate model to each country completely independently. This is the no pooling approach.

At least in this case, both of those seem unsatisfactory. It seems very likely that there is some similarity between the countries - if you learn the mean for one country, that might influence your belief about the next country you visit. So how can we incorporate that information? Essentially, we’re going to fit a model for each country, but each one will share a common prior. This prior will be influenced by the data, and since it is shared by all of the countries, it allows the information from one country to influence the model for the others.

This approach, where we share some information between the different clusters (countries, in this case) is called partial pooling. In the case where there is a hierarchical structure to the data, we call it a hierarchical model.
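To make the idea of partial pooling concrete, here is a minimal sketch for the simplest case, assuming a normal model in which both standard deviations are known (the Stan models in this post estimate them instead). Under that assumption the partially pooled estimate is a precision-weighted average of the cluster's own mean and the shared mean. The function name `partial_pool` and its arguments are made up for this illustration.

```r
# Partial pooling for a normal model with KNOWN standard deviations (an assumption).
# group_mean: sample mean of one cluster; n: number of observations in that cluster
# sigma: within-cluster sd; grand_mean, tau: mean and sd of the shared prior
partial_pool <- function(group_mean, n, sigma, grand_mean, tau) {
    data_precision <- n / sigma^2
    prior_precision <- 1 / tau^2
    weight <- data_precision / (data_precision + prior_precision)
    # weight near 1: trust the cluster's own data; weight near 0: fall back on the shared mean
    weight * group_mean + (1 - weight) * grand_mean
}

# A cluster with many observations barely moves; one with a single observation moves a lot
many_obs <- partial_pool(group_mean = 185, n = 20, sigma = 3, grand_mean = 176, tau = 2)
few_obs <- partial_pool(group_mean = 185, n = 1, sigma = 3, grand_mean = 176, tau = 2)
print(c(many_obs = many_obs, few_obs = few_obs))
```

With `n = 0` the function returns the shared mean (complete pooling), and as `tau` grows very large it returns the cluster's own mean (no pooling), so the two extreme approaches are the endpoints of the same formula.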

Let’s try this on some simulated data to see how it works before we apply it to some real data on diploma exam results in Alberta.

Simulation

First let’s set up the situation. For simplicity, we’ll assume that each person’s height is drawn from a normal distribution, and that the mean of that distribution is specific to each country. However, we’ll also assume that the means for the different countries are drawn from a common distribution as well. This will show up in the model as a shared prior.

Let’s use this process to generate some data, and then we’ll use the same model to recover it.

# Set up the environment
library(ggplot2)
library(cmdstanr)
library(posterior)
library(readxl)

options(repr.plot.width = 17, repr.plot.height = 8)
set.seed(500)

# country parameters
NUM_COUNTRIES <- 4
country_means <- rnorm(NUM_COUNTRIES, 176, 10)
country_sigma <- rexp(1, 1) # drawn but never used below; the country means above use a fixed sd of 10

# now generate some people
NUM_PEOPLE <- 20
country_data <- data.frame(
    country_index = as.factor(rep(1:NUM_COUNTRIES, each = NUM_PEOPLE)),
    height = rnorm(
        n = NUM_COUNTRIES * NUM_PEOPLE,
        mean = rep(country_means, each = NUM_PEOPLE),
        sd = 3
    )
)

ggplot(country_data, aes(height)) +
    geom_density(aes(group = country_index, colour = country_index)) +
    geom_vline(data = data.frame(country = as.factor(1:NUM_COUNTRIES), mean = country_means), mapping = aes(xintercept = mean, colour = country)) +
    labs(x = "Height (cm)", y = "Density", colour = "Country") +
    theme_bw() +  # Set background to white
    theme(
        panel.grid.major.y = element_line(size = 0.5, linetype = 'dashed', colour = "grey")
    )

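[Figure: density of the simulated heights for each country, with vertical lines at the true country means]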

To recover the values, let’s use the following model.

\begin{align*}
H_i &\sim \text{Normal}(\mu_i, \sigma) \\
\mu_i &\sim \text{Normal}(\mu_{\text{COUNTRY}[i]}, \sigma_{\text{COUNTRY}}) \\
\mu_{\text{COUNTRY}} &\sim \text{Normal}(\bar{\alpha}, 2) \\
\bar{\alpha} &\sim \text{Normal}(176, 1) \\
\sigma_{\text{COUNTRY}} &\sim \text{Exponential}(1) \\
\sigma &\sim \text{Exponential}(1)
\end{align*}

The way that we’re going to share the information between the countries is through the $\bar{\alpha}$ parameter. Because the mean of the prior must be the same for each of the different countries, it provides a way for information to flow between them as the model is fit.

Now that we have the data and know the generative process behind it, let’s ensure that we can recover those same values using the model.

stan_model_code <- "
data {
    int<lower = 0> N_COUNTRIES; // number of countries
    int<lower = 0> N; // number of people
    vector[N] heights;
    array[N] int<lower = 1, upper = N_COUNTRIES> country_id; // the country id (1-based) corresponding to each height
}
parameters {
    vector[N_COUNTRIES] mu_country;
    real<lower = 1e-3> sigma; // sd for the heights
    real<lower = 1e-3> sigma_country; // sd for the country mean values
    real alpha_bar;
}
model {
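    heights ~ normal(mu_country[country_id], sigma);
    mu_country ~ normal(alpha_bar, 10); // note: the sd is fixed at 10 here, not the 2 in the model statement
    alpha_bar ~ normal(176, 1);
    sigma_country ~ exponential(1); // prior only: sigma_country never enters the likelihood
    sigma ~ exponential(1);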
}
"
stan_model_file <- write_stan_file(stan_model_code)
stan_model <- cmdstan_model(stan_model_file)

stan_model_data <- list(
    N_COUNTRIES = NUM_COUNTRIES,
    N = nrow(country_data),
    heights = country_data$height,
    country_id = as.integer(country_data$country_index)
)
fit <- stan_model$sample(
    data = stan_model_data,
    chains = 4
)
Running MCMC with 4 sequential chains...

All 4 chains finished successfully.
Mean chain execution time: 0.0 seconds.
Total execution time: 0.5 seconds.
fit$summary()
A draws_summary: 8 × 10
variable              mean        median         sd        mad              q5          q95      rhat  ess_bulk  ess_tail
<chr>                <dbl>         <dbl>      <dbl>      <dbl>           <dbl>        <dbl>     <dbl>     <dbl>     <dbl>
lp__           -144.334351  -143.9545000  1.9631562  1.7183334   -148.04970000  -141.819900  1.000846  1640.376  2355.089
mu_country[1]   185.073222   185.0600000  0.7513226  0.7405587    183.82495000   186.320100  1.000304  6575.682  3154.603
mu_country[2]   194.750996   194.7570000  0.7414349  0.7220262    193.50895000   195.958100  1.001592  7044.055  2985.646
mu_country[3]   185.335148   185.3370000  0.7385066  0.6983046    184.07985000   186.547750  1.001719  7405.325  2834.630
mu_country[4]   176.916428   176.9180000  0.7482134  0.7650216    175.69300000   178.154050  1.001806  6358.099  2966.653
sigma             3.342929     3.3194700  0.2687868  0.2637768      2.93077650     3.815515  1.001739  6754.819  3248.544
sigma_country     1.000606     0.6943115  0.9873734  0.7047710      0.04307331     2.940308  1.000265  3774.295  1877.089
alpha_bar       176.364408   176.3655000  0.9784172  0.9748095    174.74800000   177.984050  1.000514  7930.243  2686.251

From the summary table, we can see that our posterior values are close to the true values. Let’s see the same thing graphically.

height_draws <- as_draws_matrix(fit)
head(height_draws)
A draws_matrix: 6 × 8 of type dbl
       lp__  mu_country[1]  mu_country[2]  mu_country[3]  mu_country[4]    sigma  sigma_country  alpha_bar
1  -144.016        185.422        194.549        185.765        176.804  3.57034      0.0527758    175.646
2  -143.700        184.574        194.249        184.553        176.397  3.26862      3.6745100    177.132
3  -145.799        183.510        195.764        184.747        176.220  3.36722      0.3739620    175.124
4  -144.623        185.169        193.716        184.875        176.145  3.54078      1.0502200    178.216
5  -143.933        184.849        195.680        185.868        177.684  3.21139      0.9323830    174.692
6  -144.299        185.165        193.932        184.807        176.049  3.46684      1.2301000    178.207
country_mean_results <- height_draws[, grep('mu_country', colnames(height_draws))]
results_df <- data.frame(
    country_id = 1:NUM_COUNTRIES,
    mean = apply(country_mean_results, 2, mean),
    lower = apply(country_mean_results, 2, function(row) quantile(row, 0.025)),
    upper = apply(country_mean_results, 2, function(row) quantile(row, 0.975)),
    type = "Posterior"
)

posterior_alpha_bar <- mean(height_draws[, 'alpha_bar'])

ggplot(results_df, aes(country_id, mean)) +
    geom_point(aes(colour = type)) +
    geom_errorbar(aes(ymin = lower, ymax = upper)) +
    geom_point(data = data.frame(country_id = 1:NUM_COUNTRIES, mean = country_means, type = "Actual"), mapping = aes(colour = type)) +
    geom_hline(mapping = aes(yintercept = posterior_alpha_bar), linetype = 'dashed') +
    labs(x = "Country", y = "Height (cm)", colour = "Type") +
    theme_bw() +  # Set background to white
    theme(
        panel.grid.major.y = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
    )

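[Figure: posterior mean and 95% interval of each country mean alongside the actual value; the dashed line marks the posterior mean of alpha_bar]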

One thing to note is that we have a larger error than we would expect if we fit each country individually. The reason for that is that a hierarchical model like this tends to ‘shrink’ estimates towards the mean. You can see this here by the fact that the values of the posterior tend to be closer to the mean than the actual values. While it may seem crazy to deliberately choose a model that will tend to do worse on the data, in fact this is just part of the overfitting / underfitting tradeoff - by biasing estimates to the mean, we will do worse on our training set but better on outside data.
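The shrinkage tradeoff can be illustrated with a small simulation in base R. This is only a sketch under simplifying assumptions that are not part of the model above: the within-group and between-group standard deviations are treated as known, there are many more groups than in the country example, and all of the variable names and values are invented for the illustration. It compares how far the raw group means and the shrunken group means land from the true group means.

```r
set.seed(42)

# Hypothetical setup: many small groups drawn from a common distribution
n_groups <- 2000
n_per_group <- 5
sigma <- 3        # within-group sd (assumed known)
tau <- 2          # between-group sd (assumed known)
grand_mean <- 176

true_means <- rnorm(n_groups, grand_mean, tau)
# The sample mean of n observations is Normal(true mean, sigma / sqrt(n))
sample_means <- rnorm(n_groups, true_means, sigma / sqrt(n_per_group))

# Precision-weighted shrinkage towards the grand mean
weight <- (n_per_group / sigma^2) / (n_per_group / sigma^2 + 1 / tau^2)
shrunk_means <- weight * sample_means + (1 - weight) * grand_mean

mse_no_pooling <- mean((sample_means - true_means)^2)
mse_partial_pooling <- mean((shrunk_means - true_means)^2)
print(c(no_pooling = mse_no_pooling, partial_pooling = mse_partial_pooling))
```

Each shrunken estimate is biased towards the grand mean, yet averaged over the groups it ends up closer to the truth, because the reduction in variance outweighs the added bias.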

So now that we’ve seen a relatively simple example, let’s try to build our way towards our eventual goal of applying this idea to school test scores.

Example - School Test Scores

Imagine that you have a number of different schools in a school division, and you’re interested in how the students in those schools are doing on a test. We believe that each school has an effect on the students, and we also believe that the school division has an effect on the school. Here we have two nested pools - each student belongs to a school, and each school belongs to a division.

Simulation

Imagine that you have several schools taking part in some sort of standardized exam (the simulation below uses five). Each school has some number of students that are taking the exam, and they each get a grade. The grade they get is a combination of luck, their own properties, and some influence of the school. The influence of the school is itself dependent on the school division. Let’s try to generate some data, then create a model to recover the original values. By doing so, we’re ensuring that our model is fit for purpose.

Since this model is a bit more complex, let’s take it one step at a time.

First, let’s start with a single school. The scores of the school’s students will be drawn from a beta distribution. The beta is the perfect distribution for this since it’s a distribution of probabilities, which is what we’re looking for. The beta distribution has two parameters, $\alpha$ and $\beta$. The mean is $\alpha / (\alpha + \beta)$, and the variance is $\alpha\beta / \left[(\alpha + \beta)^2 (\alpha + \beta + 1)\right]$. If we’re thinking of the beta distribution as being related to a Binomial random process (i.e. a set of coin flips), then we can think of $\alpha$ as the number of successes and $\beta$ as the number of failures. This is not quite correct, but it’s pretty close.

So let’s say that we want the school average to be about 75%. There are a few different values of $\alpha$ and $\beta$ that we could choose; we just need the mean $\alpha / (\alpha + \beta) = 0.75$. For instance, $\alpha = 3$ and $\beta = 1$ would work. So would any multiple of those: 6 and 2, 4 and $\frac{4}{3}$, &c.
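As a quick check of the formulas above, the sketch below computes the mean and variance for the first four multiples of $\alpha = 3$, $\beta = 1$. The helper names `beta_mean` and `beta_var` are made up for this illustration.

```r
# Mean and variance of a Beta(a, b) distribution, from the formulas in the text
beta_mean <- function(a, b) a / (a + b)
beta_var <- function(a, b) a * b / ((a + b)^2 * (a + b + 1))

multiples <- 1:4
means <- beta_mean(3 * multiples, 1 * multiples)
variances <- beta_var(3 * multiples, 1 * multiples)

# The mean stays at 0.75 while the variance shrinks as the multiple grows
print(data.frame(alpha = 3 * multiples, beta = 1 * multiples, mean = means, variance = variances))
```

So scaling both parameters by the same factor keeps the centre fixed and only tightens the distribution around it.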

alpha_base <- 3
beta_base <- 1

multiples <- seq.int(from = 1, to = 4)
plot_df <- data.frame(p = numeric(), density = numeric(), type = character())

probs <- seq(from = 0, to = 1, length.out = 100)
for (multiple in multiples) {
    alpha <- alpha_base * multiple
    beta <- beta_base * multiple
    data <- data.frame(p = probs, density = dbeta(probs, alpha, beta), type = paste("alpha = ", alpha, " beta = ", beta))
    plot_df <- rbind(plot_df, data)
}

ggplot(plot_df, aes(p, density, group = type, colour = type)) +
    geom_line() +
    labs(x = "p", y = "Density", colour = "Parameter Values") +
    scale_x_continuous(labels = scales::percent_format(accuracy = 1)) +
    theme_bw() +  # Set background to white
    theme(
        panel.grid.major = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
    )

[Figure: Beta densities for (alpha, beta) = (3, 1), (6, 2), (9, 3) and (12, 4)]

Note that although the mean for each of these is at $0.75$, that’s not the same as having the peak value be there.
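To put numbers on that, here is a small sketch using the standard closed form for the mode of a beta distribution, $(\alpha - 1) / (\alpha + \beta - 2)$, which holds when both parameters are greater than 1. The helper name `beta_mode` is made up for this illustration. The formula does not apply to $\alpha = 3$, $\beta = 1$; there the density is $3p^2$, which keeps rising all the way to $p = 1$.

```r
# Mode of a Beta(a, b) distribution; only valid when a > 1 and b > 1
beta_mode <- function(a, b) {
    stopifnot(a > 1, b > 1)
    (a - 1) / (a + b - 2)
}

a <- c(6, 9, 12)
b <- c(2, 3, 4)
modes <- beta_mode(a, b)

# Numerical cross-check for the first case using base R's optimize()
numeric_mode <- optimize(function(p) dbeta(p, 6, 2), interval = c(0, 1), maximum = TRUE)$maximum

print(data.frame(alpha = a, beta = b, mean = a / (a + b), mode = modes))
```

In each case the peak sits above the mean of 0.75 and moves towards it as the parameters grow.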

If we’re thinking forward to making our model, it’s not clear to me exactly what we should be thinking of the parameters $\alpha$ and $\beta$ as representing. Intuitively, I’d prefer to think in terms of the mean and variance (or even more generally, the centre and the spread) of the expected data. Luckily, it turns out that there is a reparameterization of the beta in terms of these:

\begin{align*}
\mu &= \frac{\alpha}{\alpha + \beta} & [\text{The mean - centre}] \\
\nu &= \alpha + \beta & [\text{The `sample size' - spread}]
\end{align*}

Solving for $\alpha$ and $\beta$,

\begin{align*}
\alpha &= \mu\nu \\
\beta &= (1 - \mu)\nu
\end{align*}
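The conversion in both directions is short enough to write as a pair of helpers. This is a sketch; the function names `to_shape` and `to_mean_size` are made up for the illustration and are not part of any package.

```r
# (mu, nu) -> (alpha, beta)
to_shape <- function(mu, nu) {
    list(alpha = mu * nu, beta = (1 - mu) * nu)
}

# (alpha, beta) -> (mu, nu)
to_mean_size <- function(alpha, beta) {
    list(mu = alpha / (alpha + beta), nu = alpha + beta)
}

# Round trip: a mean of 0.75 with a 'sample size' of 10
shape <- to_shape(mu = 0.75, nu = 10)
back <- to_mean_size(shape$alpha, shape$beta)
print(unlist(c(shape, back)))
```

Going there and back returns the values we started with, which is a cheap way to confirm that the algebra was solved correctly.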

Let’s recreate the above using this parameterization to convince ourselves that it works. Since we’re fixing our mean, the spread parameter is the only one that needs to change.

mu <- 0.75
spread <- 4 # 'sample size'

plot_df <- data.frame(p = numeric(), density = numeric(), type = character())

probs <- seq(from = 0, to = 1, length.out = 100)
for (sample_size in seq.int(from = 4, to = 16, by = 4)) {
    alpha <- mu * sample_size
    beta <- (1 - mu) * sample_size
    data <- data.frame(p = probs, density = dbeta(probs, alpha, beta), type = paste("mu = ", mu, " nu = ", sample_size))
    plot_df <- rbind(plot_df, data)
}

ggplot(plot_df, aes(p, density, group = type, colour = type)) +
    geom_line() +
    labs(x = "p", y = "Density", colour = "Parameter Values") +
    scale_x_continuous(labels = scales::percent_format(accuracy = 1)) +
    theme_bw() +  # Set background to white
    theme(
        panel.grid.major = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
    )

[Figure: Beta densities for mu = 0.75 and nu = 4, 8, 12 and 16]

Great! So now we have our parameterization. Let’s make sure that we can model at least this part of the problem. Let’s say that we have a school where the average test score is 75%, and let’s simulate a class of 20 students.

NUM_STUDENTS <- 20
school_mu <- 0.75
school_nu <- 10 # this is *not* the same as the number of students!

alpha <- school_mu * school_nu
beta <- (1 - school_mu) * school_nu

actual_student_scores <- rbeta(NUM_STUDENTS, alpha, beta)

p <- seq(from = 0, to = 1, length.out = 100)
line_data <- data.frame(p = p, density = dbeta(p, alpha, beta))
student_data <- data.frame(p = actual_student_scores, density = dbeta(actual_student_scores, alpha, beta))
ggplot() +
    geom_line(data = line_data, mapping = aes(p, density)) +
    geom_point(data = student_data, mapping = aes(p, density)) +
    scale_x_continuous(labels = scales::percent_format(accuracy = 1)) +
    theme_bw() +
    theme(
        panel.grid.major = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
    )

[Figure: Beta(7.5, 2.5) density with the 20 simulated student scores marked on the curve]

So here we can see the theoretical distribution of the students’ scores along with the actual scores. Now let’s use a model to recover those values!

\begin{align*}
\text{s} &\sim \text{Beta}(\alpha, \beta) \\
\alpha &= \mu\nu \\
\beta &= (1 - \mu)\nu \\
\mu &\sim \text{Beta}(12, 4) \\
\nu &\sim \text{Normal}(10, 1)
\end{align*}

How did we arrive at those priors? We ran some simulations of the model with different parameters to find ones that were reasonable. In a real model we would do this before looking at the data, but in this case that’s not really possible. We’ll just pretend that we don’t know what the data should look like. Also, when checking a model like this, it’s not unreasonable to start with a model that reflects the true underlying process - the idea here is that we are ground-truthing it. If the model can’t get back the correct values even when we know the model is correct and the values are close, then it really isn’t fit for purpose.
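Before simulating anything, it can help to look at what the two priors say on their own. The sketch below uses only base R and the closed-form mean and variance of the beta distribution; the variable names are made up for this illustration.

```r
# Prior on mu: Beta(12, 4)
prior_mean <- 12 / (12 + 4)
prior_sd <- sqrt(12 * 4 / ((12 + 4)^2 * (12 + 4 + 1)))
prior_interval <- qbeta(c(0.05, 0.95), 12, 4)   # central 90% interval for mu

# Prior on nu: Normal(10, 1)
nu_interval <- qnorm(c(0.05, 0.95), 10, 1)      # central 90% interval for nu

print(c(mean = prior_mean, sd = prior_sd))
print(prior_interval)
print(nu_interval)
```

The prior on the school mean is centred on 75% with a standard deviation of roughly 10 percentage points, and the prior on nu keeps the 'sample size' within a couple of units of 10.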

NUM_SIM_STUDENTS <- 100
nu <- rnorm(NUM_SIM_STUDENTS, 10, 1)
mu <- rbeta(NUM_SIM_STUDENTS, 12, 4)
alpha <- mu * nu
beta <- (1 - mu) * nu
s <- rbeta(NUM_SIM_STUDENTS, alpha, beta)

plot_df <- data.frame(s = s)
ggplot(plot_df, aes(s)) +
    geom_density() +
    labs(x = "Score", y = "Density") +
    coord_cartesian(xlim = c(0, 1)) +
    scale_x_continuous(labels = scales::percent_format(accuracy = 1)) +
    theme_bw() +
    theme(
        panel.grid.major = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
    )

[Figure: density of 100 scores simulated from the priors]

ggplot(data.frame(nu = nu), aes(nu)) +
    geom_density() +
    labs(x = "nu", y = "Density") + # nu is a 'sample size', not a score, so no percent axis
    theme_bw() +
    theme(
        panel.grid.major = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
    )

[Figure: density of the simulated nu values]

Before we do anything else, let’s do a prior predictive check. The idea here is to see what kind of values our model is producing before we feed it any data. In large part this should help to constrain our priors - are the data that we see coming out of the model in line with our expectations? Are they plausible? If not, we’ll have to go back and revise our priors. We can actually use Stan to help us with this.

student_score_model_pp_code <- "
data {
    int<lower = 0> N; // number of students
}
parameters {
    real<lower = 0, upper = 1> mu;
    real<lower = 0> nu;
}
 transformed parameters {
    real<lower = 0> a = mu * nu;
    real<lower = 0> b = (1 - mu) * nu;
}
model {
    mu ~ beta(12, 4);
    nu ~ normal(10, 1);
}
generated quantities {
    vector[N] s_sim;
    for (n in 1:N) {
        s_sim[n] = beta_rng(a, b);
    }
}
"

student_score_model_pp_file <- write_stan_file(student_score_model_pp_code)
student_score_pp_model <- cmdstan_model(student_score_model_pp_file)

student_score_model_pp_data <- list(
    N = 100
)

student_score_pp_fit <- student_score_pp_model$sample(
    data = student_score_model_pp_data,
    chains = 4
)
student_score_pp_fit$summary()
Running MCMC with 4 sequential chains...

All 4 chains finished successfully.
Mean chain execution time: 0.0 seconds.
Total execution time: 0.5 seconds.
A draws_summary: 105 × 10
variable          mean      median         sd        mad           q5         q95       rhat  ess_bulk  ess_tail
<chr>            <dbl>       <dbl>      <dbl>      <dbl>        <dbl>       <dbl>      <dbl>     <dbl>     <dbl>
lp__        -7.7472137  -7.3961950  1.1037489  0.7291056  -10.0643700  -6.7499540  1.0005178  1726.762  1776.441
mu           0.7536242   0.7642695  0.1042146  0.1057909    0.5686431   0.9068932  1.0015275  3183.736  1685.494
nu           9.9871312  10.0048000  1.0097168  0.9897838    8.2924295  11.6080150  1.0008779  3470.894  2249.882
a            7.5237364   7.5277800  1.2812031  1.2879643    5.3965645   9.6125255  1.0006855  3423.556  2280.696
b            2.4633951   2.3360800  1.0812236  1.0767531    0.8796699   4.3757240  1.0014289  3148.109  1744.702
s_sim[1]     0.7519102   0.7771395  0.1645582  0.1691661    0.4368936   0.9719049  0.9997345  3638.240  3666.031
s_sim[2]     0.7511981   0.7766295  0.1655453  0.1734783    0.4370754   0.9752769  1.0019282  3536.600  2732.750
s_sim[3]     0.7534096   0.7796105  0.1623245  0.1685798    0.4469912   0.9749123  1.0000188  3211.546  2890.503
s_sim[4]     0.7516449   0.7772990  0.1643784  0.1697799    0.4385267   0.9744425  0.9998382  3465.035  3140.728
s_sim[5]     0.7500511   0.7759265  0.1690192  0.1730068    0.4296476   0.9745945  1.0011223  3213.409  2759.927
s_sim[6]     0.7546354   0.7789480  0.1622561  0.1683270    0.4521205   0.9721189  1.0016519  3591.587  2984.244
s_sim[7]     0.7550554   0.7788565  0.1603555  0.1662313    0.4607637   0.9725681  1.0003974  3557.438  2779.617
s_sim[8]     0.7549163   0.7815435  0.1658129  0.1695472    0.4389618   0.9750677  0.9996789  3018.343  2864.190
s_sim[9]     0.7557622   0.7824510  0.1628229  0.1724256    0.4547386   0.9739696  1.0000825  3615.700  3202.578
s_sim[10]    0.7580823   0.7858460  0.1645778  0.1722047    0.4479410   0.9737689  0.9997881  3390.171  3311.547
s_sim[11]    0.7526162   0.7805120  0.1665224  0.1725220    0.4385346   0.9756622  1.0006392  3648.114  3135.257
s_sim[12]    0.7547576   0.7769915  0.1628121  0.1683196    0.4495085   0.9720077  0.9997057  3701.263  3590.140
s_sim[13]    0.7549515   0.7827305  0.1646374  0.1691936    0.4403020   0.9741046  1.0009201  3463.333  3247.366
s_sim[14]    0.7562782   0.7825465  0.1624611  0.1725020    0.4543448   0.9743693  1.0007592  3531.548  3317.929
s_sim[15]    0.7555521   0.7797795  0.1632866  0.1702054    0.4528257   0.9755237  0.9996543  3702.652  2749.965
s_sim[16]    0.7554666   0.7803640  0.1629003  0.1675064    0.4497360   0.9738368  0.9997797  3369.579  3193.877
s_sim[17]    0.7496401   0.7786345  0.1679984  0.1743945    0.4295800   0.9750243  1.0006391  3546.795  3577.849
s_sim[18]    0.7557703   0.7809600  0.1658010  0.1700186    0.4469792   0.9739523  1.0006038  3835.557  3256.432
s_sim[19]    0.7559281   0.7823020  0.1618966  0.1639392    0.4468340   0.9732376  0.9996622  3645.435  2824.582
s_sim[20]    0.7528602   0.7790840  0.1680849  0.1709890    0.4281760   0.9750786  1.0001791  3651.577  3719.415
s_sim[21]    0.7523741   0.7761915  0.1644295  0.1704575    0.4476327   0.9764495  0.9997132  3746.216  3333.132
s_sim[22]    0.7561313   0.7855385  0.1638816  0.1694990    0.4448916   0.9734162  1.0009543  3417.160  2777.871
s_sim[23]    0.7499353   0.7768145  0.1655774  0.1709453    0.4405739   0.9753409  1.0001462  3678.281  2928.156
s_sim[24]    0.7507685   0.7770365  0.1643030  0.1680290    0.4446655   0.9708647  0.9996292  2999.935  3179.589
s_sim[25]    0.7539305   0.7809975  0.1634029  0.1679719    0.4425114   0.9729435  1.0006023  3634.668  3420.166
...          (rows s_sim[26] to s_sim[70] not shown)
s_sim[71]    0.7523116   0.7806950  0.1657678  0.1687421    0.4386995   0.9737648  1.0013223  3680.581  2971.260
s_sim[72]    0.7583762   0.7842995  0.1617939  0.1736006    0.4545544   0.9710739  0.9996976  3301.585  2840.791
s_sim[73]    0.7556720   0.7820165  0.1627936  0.1684300    0.4478169   0.9726397  1.0012680  3499.080  3208.706
s_sim[74]    0.7547181   0.7809220  0.1610201  0.1644411    0.4516467   0.9715066  1.0000198  3399.117  3014.405
s_sim[75]    0.7533381   0.7742685  0.1628741  0.1689408    0.4496273   0.9764632  1.0006530  3359.687  3042.876
s_sim[76]    0.7532482   0.7778980  0.1649907  0.1723345    0.4448392   0.9764751  1.0001158  3729.772  3561.618
s_sim[77]    0.7540697   0.7776255  0.1620004  0.1687421    0.4493149   0.9748428  0.9999781  3709.743  2924.111
s_sim[78]    0.7552318   0.7787180  0.1617609  0.1699089    0.4518985   0.9716803  1.0002355  3542.855  3354.554
s_sim[79]    0.7513719   0.7772095  0.1658203  0.1766436    0.4388728   0.9748116  1.0006742  3688.171  3517.024
s_sim[80]    0.7554185   0.7786930  0.1624798  0.1663314    0.4521462   0.9719680  1.0010131  3649.007  3431.718
s_sim[81]    0.7529500   0.7810250  0.1663739  0.1738519    0.4480554   0.9728977  0.9997702  3675.678  3117.200
s_sim[82]    0.7584932   0.7843625  0.1628715  0.1717866    0.4469598   0.9736029  1.0007292  3708.661  2922.429
s_sim[83]    0.7551523   0.7812750  0.1621680  0.1705917    0.4538208   0.9739764  0.9998353  3542.688  3206.694
s_sim[84]    0.7517418   0.7747835  0.1657026  0.1712974    0.4443958   0.9745813  1.0003951  3590.600  3007.862
s_sim[85]    0.7546482   0.7788100  0.1613026  0.1697799    0.4558926   0.9730513  1.0003160  3801.188  3413.002
s_sim[86]    0.7567569   0.7850845  0.1637229  0.1612750    0.4398513   0.9756185  1.0008255  3110.995  3298.655
s_sim[87]    0.7539145   0.7805230  0.1656245  0.1711195    0.4407317   0.9758564  0.9999048  3385.537  3236.119
s_sim[88]    0.7532113   0.7794330  0.1635724  0.1741670    0.4490222   0.9743499  1.0008099  3492.836  3293.382
s_sim[89]    0.7517793   0.7806965  0.1663839  0.1742381    0.4344092   0.9739365  1.0005292  3617.638  3039.248
s_sim[90]    0.7536933   0.7797350  0.1630973  0.1666465    0.4467467   0.9694501  0.9997169  3493.019  3535.025
s_sim[91]    0.7573752   0.7821005  0.1613137  0.1701543    0.4571854   0.9733186  1.0001893  3728.733  2855.958
s_sim[92]    0.7499069   0.7755345  0.1668960  0.1723574    0.4416318   0.9717048  1.0002116  3454.466  2728.611
s_sim[93]    0.7544595   0.7786440  0.1623672  0.1724004    0.4516196   0.9726107  1.0008192  3533.927  3457.289
s_sim[94]    0.7578421   0.7856785  0.1638296  0.1698348    0.4472335   0.9747376  0.9999161  3605.345  3350.634
s_sim[95]    0.7530364   0.7766930  0.1627019  0.1702551    0.4503490   0.9751747  1.0001220  3690.476  2992.951
s_sim[96]    0.7544495   0.7785435  0.1628462  0.1691906    0.4541267   0.9745491  0.9999704  3711.595  3040.620
s_sim[97]    0.7531373   0.7830350  0.1628663  0.1649704    0.4441329   0.9745094  1.0005597  3653.499  3222.520
s_sim[98]    0.7550406   0.7807990  0.1623796  0.1687466    0.4426782   0.9725074  0.9997867  3655.128  3620.627
s_sim[99]    0.7545676   0.7812840  0.1648859  0.1710935    0.4401541   0.9733861  0.9995537  3310.877  3100.817
s_sim[100]   0.7513484   0.7802295  0.1663923  0.1774435    0.4426189   0.9735879  1.0006365  3509.837  3482.959

From the above summary the values of the parameters mu and nu look correct, as do the values for the simulated student scores. The rhat for each parameter also looks good (close to 1) and the effective sample size (ess_bulk) is also good. There was a minor problem in the fitting where a nan value was generated, but that’s a rare thing and only happens intermittently. Overall, this looks good! Now let’s get the draws.

draws <- as_draws_matrix(student_score_pp_fit)
head(draws)
A draws_matrix: 6 × 105 of type dbl
        lp__        mu        nu        a         b  s_sim[1]  s_sim[2]  s_sim[3]  s_sim[4]  s_sim[5]
1   -6.95483  0.667633   9.97229  6.65783  3.314450  0.456435  0.664123  0.711574  0.447643  0.734297
2   -8.32990  0.657733   8.48227  5.57907  2.903200  0.768042  0.804358  0.498766  0.565657  0.611261
3  -10.93100  0.944694   8.70912  8.22745  0.481670  0.999742  0.956379  0.906580  0.895511  0.937207
4   -7.77407  0.883106  10.01070  8.84047  1.170190  0.827503  0.868712  0.780807  0.922841  0.902845
5   -8.18839  0.899799   9.86949  8.88056  0.988932  0.982008  0.934849  0.971897  0.987508  0.929194
6   -7.55440  0.646480  11.06220  7.15148  3.910700  0.676698  0.280590  0.649463  0.611321  0.680222

(columns s_sim[6] to s_sim[90] not shown)

   s_sim[91]  s_sim[92]  s_sim[93]  s_sim[94]  s_sim[95]  s_sim[96]  s_sim[97]  s_sim[98]  s_sim[99]  s_sim[100]
1   0.498274   0.754939   0.726993   0.456323   0.703468   0.969384   0.719329   0.285399   0.578132    0.768713
2   0.665832   0.934308   0.768160   0.528255   0.680999   0.446552   0.549291   0.703405   0.675308    0.620042
3   0.991936   0.811721   0.989481   0.901462   0.982731   0.996148   0.797290   0.964812   0.756131    0.917414
4   0.725912   0.814534   0.861869   0.816770   0.909844   0.965625   0.880752   0.953130   0.911444    0.711088
5   0.996558   0.889514   0.993364   0.945195   0.745194   0.865354   0.919428   0.896571   0.738742    0.974191
6   0.632214   0.426887   0.810964   0.408792   0.752207   0.704922   0.556105   0.785266   0.632428    0.470659

Now that we have the draws, let’s plot a few different classes of 100 students (s_sim) to see what kind of variation the model is producing.

set.seed(2024)
s_sim_draws <- draws[, grep('s_sim', colnames(draws))]
s_sim_means <- mean(apply(s_sim_draws, 2, mean))

p <- ggplot() +
    coord_cartesian(xlim = c(0, 1)) +
    geom_vline(data = data.frame(xintercept = s_sim_means), mapping = aes(xintercept=xintercept)) +
    labs(x = "Score", y = "Density") +
    scale_x_continuous(labels = scales::percent_format(accuracy = 1)) + # scores are on the x axis; y is a density
    theme_bw() +  # Set background to white
    theme(
        panel.grid.major = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
    )

for (row in sample(1:4000, size = 100)) {
    data <- data.frame(s = as.vector(s_sim_draws[row, ]))
    p <- p + geom_density(data = data, mapping = aes(s), colour = alpha('black', 0.1))
}

print(p)

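[Figure: score densities for 100 randomly chosen simulated classes, with a vertical line at the overall mean]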

What we see here are the results of 100 randomly chosen simulated classes. This seems roughly reasonable to me - the mean is firmly centred around 75% and there’s considerable variation. Looking at this, there is maybe a bit too much variation - I would expect almost all of the weight to be in the range 50% - 100%, and there are a few classes with a significant amount of weight below 50%. On the other hand, it’s probably not the worst thing in the world for this to have too much variation - the data should be able to fix that.
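One way to put a number on "too much variation" is to estimate how much of the prior predictive mass falls below 50%. The sketch below does this by direct Monte Carlo in base R using the same priors; it pools individual scores across all simulated classes rather than looking at each class separately, and the seed and number of draws are arbitrary choices for the illustration.

```r
set.seed(2024)
n_draws <- 100000

# Draw the parameters from their priors, then one score for each parameter draw
mu <- rbeta(n_draws, 12, 4)
nu <- rnorm(n_draws, 10, 1)
scores <- rbeta(n_draws, mu * nu, (1 - mu) * nu)

share_below_half <- mean(scores < 0.5)
print(c(mean_score = mean(scores), share_below_half = share_below_half))
```

If that share feels too large for the exam in question, tightening the prior on mu or raising the centre of the prior on nu would pull more of the mass above 50%.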

Now let’s actually train the model on some data and see what happens. We’ll have to slightly modify the model we used for our prior predictive check to incorporate the data, but otherwise it will look very similar.

student_score_model_code <- "
data {
    int<lower = 0> N; // number of students
    vector[N] s; // student scores
}
parameters {
    real<lower = 0, upper = 1> mu;
    real<lower = 0> nu;
}
 transformed parameters {
    real<lower = 0> a = mu * nu;
    real<lower = 0> b = (1 - mu) * nu;
}
model {
    s ~ beta(a, b);
    mu ~ beta(12, 4);
    nu ~ normal(10, 1);
}
"

student_score_model_file <- write_stan_file(student_score_model_code)
student_score_model <- cmdstan_model(student_score_model_file)

student_score_model_data <- list(
    N = NUM_STUDENTS,
    s = actual_student_scores
)
student_score_fit <- student_score_model$sample(
    data = student_score_model_data,
    chains = 4
)
student_score_fit$summary()
Running MCMC with 4 sequential chains...

All 4 chains finished successfully.
Mean chain execution time: 0.0 seconds.
Total execution time: 0.6 seconds.
A draws_summary: 5 × 10
variable       mean     median         sd         mad         q5         q95      rhat  ess_bulk  ess_tail
<chr>         <dbl>      <dbl>      <dbl>       <dbl>      <dbl>       <dbl>     <dbl>     <dbl>     <dbl>
lp__      7.9908730  8.3131950  0.9971564  0.70121791  5.9504810   8.9398085  1.001998  1874.212  2717.678
mu        0.8019337  0.8031015  0.0247588  0.02413969  0.7597064   0.8411199  1.000958  2543.792  2389.431
nu        9.9267693  9.9146750  0.9365335  0.93079852  8.4171405  11.4982400  1.000450  3232.958  2479.932
a         7.9625121  7.9612550  0.8081238  0.81650489  6.6524980   9.3088245  0.999826  3037.912  2657.738
b         1.9642571  1.9528450  0.2943620  0.29113816  1.4906390   2.4721195  1.001831  3197.077  2450.331
draws <- as_draws_matrix(student_score_fit)

# mu
ggplot(data.frame(mu = draws[, 'mu'])) +
    geom_density(aes( mu )) +
    geom_vline(aes(xintercept = 0.75)) +
    coord_cartesian(xlim = c(0, 1)) +
    labs(x = "mu", y = "Density") +
    scale_x_continuous(labels = scales::percent_format(accuracy = 1)) +
    theme_bw() +  # Set background to white
    theme(
        panel.grid.major = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
    )

# nu
ggplot(data.frame(nu = draws[, 'nu'])) +
    geom_density(aes(nu)) +
    geom_vline(aes(xintercept = 10)) +
    labs(x = "nu", y = "Density") +
    theme_bw() +  # Set background to white
    theme(
        panel.grid.major = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
    )

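[Figure: posterior density of mu, with a vertical line at the true value of 75%]

[Figure: posterior density of nu, with a vertical line at the true value of 10]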

Great! So this model seems to be doing a reasonable job of recovering the actual values: the true nu of 10 sits in the middle of its posterior, while the posterior for mu is centred near 80%, a little above the true 75%, which is not too surprising with only 20 students. Now that we have some confidence in this part of the model, let’s extend it a bit more to consider the situation of different schools, each with their own means, which now influence the scores of their students. We’ll assume that the schools each draw their means from a single beta distribution, and we’ll also assume that each student draws their score from a beta distribution whose mean is the school mean.

set.seed(1234)
NUM_SCHOOLS <- 5
school_mu <- 0.75
school_nu <- 50

school_alpha <- school_mu * school_nu
school_beta <- (1 - school_mu) * school_nu

school_means <- rbeta(NUM_SCHOOLS, school_alpha, school_beta)

ps <- seq(from = 0, to = 1, length.out = 100)
density <- dbeta(ps, school_alpha, school_beta)
ggplot() +
    geom_line(data = data.frame(p = ps, density = density), mapping = aes(p, density)) +
    geom_point(data = data.frame(p = school_means, density = dbeta(school_means, school_alpha, school_beta)), mapping = aes(p, density)) +
    labs(x = "School-Specific Mean", y = "Density") +
    scale_x_continuous(labels = scales::percent_format(accuracy = 1)) +
    theme_bw() +  # Set background to white
    theme(
        panel.grid.major = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
    )

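[Figure: Beta(37.5, 12.5) density of school-specific means, with the five simulated school means marked on the curve]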

Now we’ll simulate some students for each of these schools.

STUDENT_NU <- 30

student_alpha <- school_means * STUDENT_NU
student_beta <- (1 - school_means) * STUDENT_NU

ps <- seq(from = 0, to = 1, length.out = 100)
plot_df <- data.frame(p = numeric(), density = numeric(), school = integer())
for (school_id in 1:NUM_SCHOOLS) {
    alpha <- student_alpha[school_id]
    beta <- student_beta[school_id]
    school_data <- data.frame(
        p = ps,
        density = dbeta(ps, alpha, beta),
        school = school_id
    )
    plot_df <- rbind(plot_df, school_data)
}
plot_df$school <- as.factor(plot_df$school)

ggplot(plot_df, aes(p, density, group = school, colour = school)) +
    geom_line() +
    labs(x = "p", y = 'Density', colour = "School") +
    scale_x_continuous(labels = scales::percent_format(accuracy = 1)) +
    theme_bw() +  # Set background to white
    theme(
        panel.grid.major = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
    )

[Figure: the score distribution for each of the five schools]

Here we can see the distribution of scores from which each student will draw their scores.

set.seed(1234)
STUDENTS_PER_SCHOOL <- 20
student_scores <- data.frame(s = numeric(), school = integer())

for (school in 1:NUM_SCHOOLS) {
    alpha <- student_alpha[school]
    beta <- student_beta[school]
    school_data <- data.frame(
        s = rbeta(STUDENTS_PER_SCHOOL, alpha, beta),
        school = school
    )
    student_scores <- rbind(student_scores, school_data)
}
student_scores$school <- as.factor(student_scores$school)

ggplot(student_scores, aes(s, colour = school, group = school)) +
    geom_density() +
    coord_cartesian(xlim = c(0, 1)) +
    labs(x = 'Score', y = 'Density', colour = "School") +
    scale_x_continuous(labels = scales::percent_format(accuracy = 1)) +
    theme_bw() +
    theme(
        panel.grid.major = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
    )


plot_df <- data.frame(true_mean = numeric(), empirical_mean = numeric())
for (school_id in 1:NUM_SCHOOLS) {
    school_data <- student_scores[student_scores$school == school_id, ]
    plot_df <- rbind(plot_df, data.frame(true_mean = school_means[school_id], empirical_mean = mean(school_data$s)))
}

empirical_school_means <- plot_df$empirical_mean

ggplot(plot_df, aes(true_mean, empirical_mean)) +
    geom_point() +
    geom_abline(intercept = 0, slope = 1, linetype = 'dashed') +
    labs(x = "True Mean", y = "Empirical Mean") +
    scale_x_continuous(labels = scales::percent_format(accuracy = 1)) +
    theme_bw() +
    theme(
        panel.grid.major = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
    )


First, let’s run an unpooled model - that is, one where we estimate each school’s values separately. That model looks like

\begin{align*}
\text{s} &\sim \text{Beta}(\alpha[\text{school}], \beta[\text{school}]) \\
\alpha[\text{school}] &= \mu[\text{school}] * \nu[\text{school}] \\
\beta[\text{school}] &= (1 - \mu[\text{school}]) * \nu[\text{school}] \\
\mu[\text{school}] &\sim \text{Beta}(3, 1) \\
\nu[\text{school}] &\sim \text{Normal}(30, 4)
\end{align*}

student_school_unpooled_model_code <- "
data {
    int<lower = 0> N_students; // number of students
    int<lower = 0> N_schools; // number of schools
    vector[N_students] s; // student scores
    array[N_students] int school_id; // the school id for each student
}
parameters {
    array[N_schools] real<lower = 0, upper = 1> mu_school;
    array[N_schools] real<lower = 0> nu_school;
}
 transformed parameters {
    array[N_schools] real<lower = 0> a_school;
    array[N_schools] real<lower = 0> b_school;
    for (i in 1:N_schools) {
        a_school[i] = mu_school[i] * nu_school[i];
        b_school[i] = (1 - mu_school[i]) * nu_school[i];
    }
}
model {
    for (student_index in 1:N_students) {
        s[student_index] ~ beta(a_school[school_id[student_index]], b_school[school_id[student_index]]);
    }
    mu_school ~ beta(3, 1);
    nu_school ~ normal(30, 4);
}
"

student_school_unpooled_model_file <- write_stan_file(student_school_unpooled_model_code)
student_school_unpooled_model <- cmdstan_model(student_school_unpooled_model_file)

student_school_model_data <- list(
    N_students = nrow(student_scores),
    N_schools = NUM_SCHOOLS,
    s = student_scores$s,
    school_id = as.integer(student_scores$school)
)
student_school_unpooled_fit <- student_school_unpooled_model$sample(
    data = student_school_model_data,
    chains = 4
)
student_school_unpooled_fit$summary()
Running MCMC with 4 sequential chains...

All 4 chains finished successfully.
Mean chain execution time: 0.2 seconds.
Total execution time: 0.9 seconds.
A draws_summary: 21 × 10
variable      mean         median       sd          mad         q5           q95          rhat       ess_bulk  ess_tail
<chr>         <dbl>        <dbl>        <dbl>       <dbl>       <dbl>        <dbl>        <dbl>      <dbl>     <dbl>
lp__          133.2063127  133.5330000  2.28559835  2.11715280  128.9039000  136.2762000  1.0005406  1383.865  2676.246
mu_school[1]  0.8466306    0.8467865    0.01451341  0.01407951  0.8222164    0.8700590    1.0012989  9296.896  2773.345
mu_school[2]  0.7631883    0.7633160    0.01666534  0.01648503  0.7358237    0.7897457    1.0000816  7172.713  2855.873
mu_school[3]  0.6985636    0.6984615    0.01903218  0.01865037  0.6669859    0.7296387    1.0027091  7569.557  2383.093
mu_school[4]  0.8986209    0.8987715    0.01110209  0.01102758  0.8803896    0.9161687    1.0006262  6881.045  2579.987
mu_school[5]  0.7082870    0.7088845    0.01838196  0.01870893  0.6780569    0.7377980    1.0002342  8158.005  2954.655
nu_school[1]  30.5954286   30.5384500   3.60381506  3.63333369  24.7442450   36.6404500   1.0011241  6743.832  2632.137
nu_school[2]  30.3578492   30.3070000   3.71383749  3.70813086  24.3044150   36.6859250   1.0022297  8089.165  2602.090
nu_school[3]  28.9024282   28.8947500   3.59228813  3.67084347  22.9660100   34.8531500   1.0014781  7393.074  2979.367
nu_school[4]  32.0602813   32.0718500   3.67326279  3.72807183  26.0406500   38.0546350   1.0018923  7189.690  2859.561
nu_school[5]  30.7457093   30.6825500   3.78388583  3.78374346  24.5528400   37.0404800   1.0004906  7702.508  2870.507
a_school[1]   25.9068463   25.8534500   3.11506217  3.11612868  20.8941350   31.2008100   1.0008923  6732.695  2500.360
a_school[2]   23.1730379   23.1326500   2.91251817  2.86645884  18.5673600   28.0946000   1.0028086  7867.279  2794.205
a_school[3]   20.1909199   20.2093000   2.57330288  2.56801146  15.9950150   24.4977450   1.0010490  7449.585  2740.689
a_school[4]   28.8145096   28.8517000   3.35774104  3.36780003  23.3074450   34.2775750   1.0015005  7155.198  2752.986
a_school[5]   21.7806440   21.6924500   2.76944387  2.74147566  17.2344200   26.5064300   1.0012976  7711.507  3129.558
b_school[1]   4.6885823    4.6692700    0.68187457  0.69292276  3.5813565    5.8415735    0.9999706  8123.741  2632.944
b_school[2]   7.1848123    7.1599950    0.98184417  1.00975438  5.6196490    8.8247815    1.0013434  8428.792  2801.349
b_school[3]   8.7115083    8.6948950    1.20813189  1.27687442  6.7770070    10.7226350   1.0012211  7377.469  2786.540
b_school[4]   3.2457709    3.2309500    0.48446246  0.49350565  2.4725655    4.0622545    1.0008731  7247.993  2369.960
b_school[5]   8.9650650    8.9528950    1.20969315  1.20415289  6.9710085    11.0267150   1.0012507  7371.608  2881.813
unpooled_student_school_draws <- as_draws_matrix(student_school_unpooled_fit$draws())
# school means
unpooled_school_means_draws <- unpooled_student_school_draws[, grep('mu_school', colnames(unpooled_student_school_draws))]
unpooled_school_mean_means <- apply(unpooled_school_means_draws, 2, mean)
unpooled_school_mean_lower <- apply(unpooled_school_means_draws, 2, function(col) quantile(col, 0.025))
unpooled_school_mean_upper <- apply(unpooled_school_means_draws, 2, function(col) quantile(col, 0.975))

unpooled_means_plot_df <- data.frame(school = 1:NUM_SCHOOLS, mean = unpooled_school_mean_means, lower = unpooled_school_mean_lower, upper = unpooled_school_mean_upper)

unpooled_school_means_plot <- ggplot() +
    geom_errorbar(data = unpooled_means_plot_df, mapping = aes(school, ymin = lower, ymax = upper)) +
    geom_point(data = unpooled_means_plot_df, aes(school, mean)) +
    geom_point(data = data.frame(school = 1:NUM_SCHOOLS, mean = school_means), mapping = aes(school, mean, colour = "True Mean")) +
    geom_point(data = data.frame(school = 1:NUM_SCHOOLS, mean = empirical_school_means), mapping = aes(school, mean, colour = 'Empirical Mean')) +
    labs(x = "School", y = "Mean Score", colour = "Type") +
    scale_y_continuous(labels = scales::percent_format(accuracy = 1)) +
    theme_bw() +
    theme(
        panel.grid.major = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
    )
print(unpooled_school_means_plot)


This looks pretty good! We’re doing a good job of recovering these values. Now let’s share information among the schools by using a hierarchical model.

We’ll use the following model to recover the values.

\begin{align*}
\text{s} &\sim \text{Beta}(\alpha[\text{school}], \beta[\text{school}]) \\
\alpha[\text{school}] &= \mu[\text{school}] * \nu[\text{school}] \\
\beta[\text{school}] &= (1 - \mu[\text{school}]) * \nu[\text{school}] \\
\mu[\text{school}] &\sim \text{Beta}(\bar\alpha, \bar\beta) \\
\nu[\text{school}] &\sim \text{Normal}(\bar\nu, \sigma_\nu) \\
\bar\alpha &\sim \text{Normal}(3, 0.5) \\
\bar\beta &\sim \text{Normal}(1, 0.3) \\
\bar\nu &\sim \text{Normal}(30, 5) \\
\sigma_\nu &\sim \text{Exponential}(1)
\end{align*}
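To get a feel for what these hyperpriors imply about a school mean before seeing any data, one option is to simulate from them. This is only a rough sketch: the variable names are made up for illustration, and negative hyperparameter draws are simply discarded to mimic the positivity constraints in the Stan program.

```r
set.seed(1)
n_draws <- 10000

# Draw from the hyperpriors on alpha_bar and beta_bar.
alpha_bar_prior <- rnorm(n_draws, mean = 3, sd = 0.5)
beta_bar_prior <- rnorm(n_draws, mean = 1, sd = 0.3)

# Both must be positive to be valid beta shape parameters, so drop the
# (rare) draws that are not. This is a simple rejection step.
keep <- alpha_bar_prior > 0 & beta_bar_prior > 0
alpha_bar_prior <- alpha_bar_prior[keep]
beta_bar_prior <- beta_bar_prior[keep]

# One simulated school mean per retained hyperparameter draw.
mu_prior <- rbeta(length(alpha_bar_prior), alpha_bar_prior, beta_bar_prior)

prior_summary <- c(mean = mean(mu_prior), quantile(mu_prior, c(0.05, 0.95)))
print(prior_summary)
```

If the implied range looked implausible for test scores, that would be a reason to revisit the hyperpriors before fitting.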

student_school_model_code <- "
data {
    int<lower = 0> N_students; // number of students
    int<lower = 0> N_schools; // number of schools
    vector[N_students] s; // student scores
    array[N_students] int school_id; // the school id for each student
}
parameters {
    array[N_schools] real<lower = 0, upper = 1> mu_school;
    array[N_schools] real<lower = 0> nu_school_raw; // non-centred parameterization; with the lower bound this is half-normal, so nu_school >= nu_bar
    real<lower = 0> alpha_bar;
    real<lower = 0> beta_bar;
    real<lower = 0> nu_bar;
    real<lower = 0> sigma_nu;
}
 transformed parameters {
    array[N_schools] real<lower = 0> nu_school;
    array[N_schools] real<lower = 0> a_school;
    array[N_schools] real<lower = 0> b_school;
    for (i in 1:N_schools) {
        nu_school[i] = nu_bar + sigma_nu * nu_school_raw[i];
        a_school[i] = mu_school[i] * nu_school[i];
        b_school[i] = (1 - mu_school[i]) * nu_school[i];
    }
}
model {
    for (student_index in 1:N_students) {
        s[student_index] ~ beta(a_school[school_id[student_index]], b_school[school_id[student_index]]);
    }
    mu_school ~ beta(alpha_bar, beta_bar);
    nu_school_raw ~ std_normal();
    alpha_bar ~ normal(3, 0.5);
    beta_bar ~ normal(1, 0.3);
    nu_bar ~ normal(30, 5);
    sigma_nu ~ exponential(1);
}
"

student_school_model_file <- write_stan_file(student_school_model_code)
student_school_model <- cmdstan_model(student_school_model_file)

student_school_fit <- student_school_model$sample(
    data = student_school_model_data,
    chains = 4,
    iter_warmup = 1000,
    adapt_delta = 0.95
)
student_school_fit$summary()
Running MCMC with 4 sequential chains...

All 4 chains finished successfully.
Mean chain execution time: 0.4 seconds.
Total execution time: 1.9 seconds.
A draws_summary: 30 × 10
variable          mean         median       sd          mad         q5            q95          rhat       ess_bulk  ess_tail
<chr>             <dbl>        <dbl>        <dbl>       <dbl>       <dbl>         <dbl>        <dbl>      <dbl>     <dbl>
lp__              120.2997945  120.6395000  2.88945094  2.76875550  114.91420000  124.4150000  1.0048682  1490.839  2390.733
mu_school[1]      0.8470211    0.8472085    0.01358558  0.01298165  0.82471955    0.8688271    1.0010532  5050.164  2470.032
mu_school[2]      0.7638199    0.7639385    0.01594188  0.01603803  0.73703180    0.7900103    1.0010624  5844.673  3214.262
mu_school[3]      0.6989982    0.6991170    0.01758573  0.01766592  0.67012455    0.7275320    1.0017414  5188.370  3085.584
mu_school[4]      0.8983489    0.8985485    0.01102625  0.01115063  0.88029595    0.9159538    1.0025573  5146.175  3025.248
mu_school[5]      0.7082870    0.7083660    0.01716313  0.01705657  0.67989185    0.7368971    0.9997209  5376.621  3100.150
nu_school_raw[1]  0.8128569    0.6926650    0.61102128  0.60914400  0.07260510    1.9961565    1.0004995  3516.625  1721.888
nu_school_raw[2]  0.7895747    0.6723220    0.60165921  0.59760344  0.05733236    1.9660335    1.0005797  3220.242  1853.584
nu_school_raw[3]  0.7553973    0.6324210    0.58797066  0.56209072  0.05570231    1.9002925    1.0004686  2797.938  1643.460
nu_school_raw[4]  0.8333840    0.7106100    0.62189482  0.62426430  0.05943881    2.0247330    0.9998461  2781.631  1139.225
nu_school_raw[5]  0.8048431    0.6786530    0.61165042  0.60419731  0.06066773    2.0173890    1.0009278  2768.546  1470.845
alpha_bar         3.1320128    3.1337150    0.46785786  0.47165954  2.35937800    3.8970855    1.0019146  5119.352  2751.850
beta_bar          1.0907279    1.0844350    0.24464134  0.24603302  0.69713060    1.5043975    1.0009313  5496.958  2344.933
nu_bar            31.5787906   31.5582500   3.37920085  3.41346411  26.21277500   37.1670300   1.0003554  4995.001  3090.744
sigma_nu          1.0272002    0.7143280    1.02022647  0.73669653  0.05808706    3.0601565    1.0002520  4038.993  2059.170
nu_school[1]      32.4157204   32.3271500   3.43742095  3.45297540  26.91420000   38.1362350   1.0009389  4991.626  3188.874
nu_school[2]      32.3700213   32.2914000   3.43706176  3.44141112  26.87087000   38.0587250   1.0006216  5166.621  3033.495
nu_school[3]      32.3245989   32.2764500   3.42172188  3.45601473  26.84857000   38.0114750   1.0005946  5223.580  3131.844
nu_school[4]      32.4816132   32.4246500   3.47454623  3.44066982  26.96652500   38.3314700   1.0007982  5070.896  3205.803
nu_school[5]      32.3932509   32.2758500   3.43845039  3.44600718  26.99057000   38.1477100   1.0007019  5338.495  3007.364
a_school[1]       27.4618737   27.3681500   2.99122402  2.98136034  22.71480500   32.3846850   1.0005317  4961.919  3132.647
a_school[2]       24.7285854   24.6540000   2.70788908  2.68691598  20.37299000   29.2811650   1.0002278  5146.804  3083.624
a_school[3]       22.5988739   22.5553500   2.49258392  2.49840339  18.63084000   26.8150650   1.0001400  5127.857  3115.493
a_school[4]       29.1838142   29.0871500   3.17859264  3.14333439  24.18315500   34.4842300   1.0010808  4943.275  3131.463
a_school[5]       22.9453409   22.8683500   2.51432791  2.52901908  19.04795500   27.2467150   1.0001032  5396.115  3104.879
b_school[1]       4.9538464    4.9310000    0.64619137  0.64508667  3.93117700    6.0525920    1.0019177  5480.992  3080.917
b_school[2]       7.6414370    7.6174150    0.93165746  0.91027192  6.15480500    9.2250720    1.0012027  5484.039  3014.747
b_school[3]       9.7257252    9.6761700    1.14435871  1.16653192  7.94368300    11.6919750   1.0018870  5549.259  2891.207
b_school[4]       3.2977995    3.2701050    0.47490024  0.46661129  2.53166850    4.1071670    1.0008953  5902.171  2746.835
b_school[5]       9.4479103    9.4035900    1.13083864  1.13598295  7.68017550    11.3531550   1.0005881  5082.952  3188.104
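
One thing worth noticing in this summary: `nu_school_raw` is declared with a lower bound of zero, so its `std_normal()` prior acts as a half-normal, and `nu_school = nu_bar + sigma_nu * nu_school_raw` can never fall below `nu_bar`. That would explain the posterior means near 0.8 for `nu_school_raw`, since the mean of a half-normal is sqrt(2 / pi), roughly 0.80. A minimal sketch of the difference, using made-up values for `nu_bar` and `sigma_nu`:

```r
set.seed(1)
nu_bar <- 30     # hypothetical value, chosen only for illustration
sigma_nu <- 1    # hypothetical value, chosen only for illustration
z <- rnorm(1e5)

# Unconstrained z: an ordinary non-centred normal, centred on nu_bar.
nu_unconstrained <- nu_bar + sigma_nu * z

# z constrained to be positive: a half-normal offset, so every draw sits at
# or above nu_bar.
nu_half <- nu_bar + sigma_nu * abs(z)

print(c(
    mean_unconstrained = mean(nu_unconstrained),
    mean_half_normal = mean(nu_half),
    theoretical_half_offset = sigma_nu * sqrt(2 / pi)
))
```

With `sigma_nu` estimated to be small, as it is here, the practical difference between the two versions is likely modest, but it is worth keeping in mind when reading the `nu_school` rows.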
student_school_draws <- as_draws_matrix(student_school_fit$draws())
# school means
school_means_draws <- student_school_draws[, grep('mu_school', colnames(student_school_draws))]
school_mean_means <- apply(school_means_draws, 2, mean)
school_mean_lower <- apply(school_means_draws, 2, function(col) quantile(col, 0.025))
school_mean_upper <- apply(school_means_draws, 2, function(col) quantile(col, 0.975))

means_plot_df <- data.frame(school = 1:NUM_SCHOOLS, mean = school_mean_means, lower = school_mean_lower, upper = school_mean_upper)

prior_mean <- mean(student_school_draws[, 'alpha_bar'] / (student_school_draws[, 'alpha_bar'] + student_school_draws[, 'beta_bar']))

pooled_school_means_plot <- ggplot() +
    geom_errorbar(data = means_plot_df, mapping = aes(school, ymin = lower, ymax = upper)) +
    geom_point(data = means_plot_df, aes(school, mean)) +
    geom_point(data = data.frame(school = 1:NUM_SCHOOLS, mean = school_means), mapping = aes(school, mean, colour = 'True Mean')) +
    geom_point(data = data.frame(school = 1:NUM_SCHOOLS, mean = unpooled_school_mean_means), aes(school, mean, colour = "Unpooled Mean")) +
    geom_hline(mapping = aes(yintercept = prior_mean), linetype = 'dashed') +
    labs(x = "School", y = "Value", colour = "Mean Type") +
    scale_y_continuous(labels = scales::percent_format(accuracy = 1)) +
    theme_bw() +
    theme(
        panel.grid.major = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
    )
print(pooled_school_means_plot)


So as we can see, pooling didn’t have much of an effect here; there’s only a minimal difference between the pooled mean (black) and the unpooled mean (blue).
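
To put a number on "minimal", here is a small sketch that compares the posterior means of `mu_school` from the two fits. The values are copied by hand from the two summary tables above; with the fitted objects in memory, `unpooled_school_mean_means` and `school_mean_means` could be used instead.

```r
# Posterior means of mu_school[1..5], copied from the two summaries above.
unpooled <- c(0.8466306, 0.7631883, 0.6985636, 0.8986209, 0.7082870)
partially_pooled <- c(0.8470211, 0.7638199, 0.6989982, 0.8983489, 0.7082870)

comparison <- data.frame(
    school = seq_along(unpooled),
    unpooled = unpooled,
    partially_pooled = partially_pooled,
    difference = partially_pooled - unpooled
)
print(comparison)

# Largest absolute shift in any school's posterior mean (about 0.0006).
max_shift <- max(abs(comparison$difference))
```

Every school moves by well under a tenth of a percentage point, which is consistent with each school having enough students for its own data to dominate the shared prior.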

Now we’re ready to tackle a more complex hierarchical model with multiple layers! To start off, let’s say we have some districts. Each district will have some schools, and each school will have some students. Each student’s score will be drawn from a beta distribution whose mean is the school mean. In turn, the school means within a district will all be drawn from a shared distribution whose mean is the district mean. Finally, the district means themselves will be drawn from one more shared distribution.
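
Written out, the generative process looks roughly like this. The constants are read off the simulation code that follows (an overall mean of 0.75 with concentration 30 for districts, concentration 40 for schools within a district, and concentration 30 for students within a school), so treat this as a summary of that code rather than a separate specification.

```latex
\begin{align*}
\mu[\text{district}] &\sim \text{Beta}(0.75 * 30,\ (1 - 0.75) * 30) \\
\mu[\text{school}] &\sim \text{Beta}(\mu[\text{district}] * 40,\ (1 - \mu[\text{district}]) * 40) \\
\text{s} &\sim \text{Beta}(\mu[\text{school}] * 30,\ (1 - \mu[\text{school}]) * 30)
\end{align*}
```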

set.seed(1234)

NUM_DISTRICTS <- 3
SCHOOLS_PER_DISTRICT <- 3
TOTAL_SCHOOLS <- NUM_DISTRICTS * SCHOOLS_PER_DISTRICT
STUDENTS_PER_SCHOOL <- 20
TOTAL_STUDENTS <- TOTAL_SCHOOLS * STUDENTS_PER_SCHOOL

district_mu <- 0.75
district_nu <- 30
district_alpha <- district_mu * district_nu
district_beta <- (1 - district_mu) * district_nu

district_means <- rbeta(NUM_DISTRICTS, district_alpha, district_beta)
ps <- seq(from = 0, to = 1, length.out = 100)
district_distribution_df <- data.frame(
    p = ps,
    density = dbeta(ps, district_alpha, district_beta)
)
districts_df <- data.frame(
    p = district_means,
    density = dbeta(district_means, district_alpha, district_beta)
)
ggplot(NULL, aes(p, density)) +
    geom_line(data = district_distribution_df) +
    geom_point(data = districts_df) +
    labs(x = "p", y = "Density") +
    scale_x_continuous(labels = scales::percent_format(accuracy = 1)) +
    theme_bw() +
    theme(
        panel.grid.major = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
    )


So now we have the district means coming from a single distribution. Within that, each district will have schools, each of whose means comes from a distribution whose mean is the district mean!

set.seed(2024)
ps <- seq(from = 0, to = 1, length.out = 1e3)
school_params <- data.frame(district = integer(), school_id = integer(), school_mu = numeric())
district_nu <- 40
for (district_id in 1:NUM_DISTRICTS) {
    district_mu <- district_means[district_id]
    district_alpha <- district_mu * district_nu
    district_beta <- (1 - district_mu) * district_nu
    school_data <- data.frame(
        district = rep(district_id, SCHOOLS_PER_DISTRICT),
        school_id = (district_id - 1) * SCHOOLS_PER_DISTRICT + 1:SCHOOLS_PER_DISTRICT,  # unique ids 1..TOTAL_SCHOOLS
        school_mu = rbeta(SCHOOLS_PER_DISTRICT, district_alpha, district_beta)
    )
    school_params <- rbind(school_params, school_data)

    district_distribution_df <- data.frame(p = ps, density = dbeta(ps, district_alpha, district_beta))
    school_means_df <- data.frame(p = school_data$school_mu, density = dbeta(school_data$school_mu, district_alpha, district_beta))
    school_plot <- ggplot(NULL, aes(p, density)) +
        geom_line(data = district_distribution_df) +
        geom_point(data = school_means_df) +
        labs(x = "p", y = "Density", title = paste("Distribution and school means for district ", district_id)) +
        scale_x_continuous(labels = scales::percent_format(accuracy = 1)) +
        theme_bw() +
        theme(
            panel.grid.major = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
        )
    print(school_plot)
}


Now that we have the school information, we can generate a simulated group of students for each school. Each student will draw their score from a distribution whose mean is the school mean.

set.seed(2024)
ps <- seq(from = 0, to = 1, length.out = 1e3)
student_data <- data.frame(district_id = integer(), school_id = integer(), score = numeric())
student_nu <- 30
for (school_id in 1:TOTAL_SCHOOLS) {
    district_id <- school_params[school_id, 'district']
    school_mu <- school_params[school_id, 'school_mu']
    student_alpha <- school_mu * student_nu
    student_beta <- (1 - school_mu) * student_nu
    school_scores <- data.frame(
        district_id = rep(district_id, STUDENTS_PER_SCHOOL),
        school_id = rep(school_id, STUDENTS_PER_SCHOOL),
        score = rbeta(STUDENTS_PER_SCHOOL, student_alpha, student_beta)
    )
    student_data <- rbind(student_data, school_scores)
}
student_data$student_id <- as.factor(seq_len(nrow(student_data)))
student_data$school_id <- as.factor(student_data$school_id)
student_data$district_id <- as.factor(student_data$district_id)

student_plot <- ggplot(student_data, aes(student_id, score)) +
    geom_point(aes(colour = school_id, shape = district_id)) +
    labs(x = "Student", y = "Score", colour = "School", shape = "District") +
    scale_y_continuous(labels = scales::percent_format(accuracy = 1)) +
    theme_bw() +  # Set background to white
    theme(
        panel.grid.major.x = element_blank(),
        panel.grid.minor.x = element_blank(),
        panel.grid.major.y = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
        axis.ticks.x = element_blank(),
        axis.text.x = element_blank(),
    )
for (district_id in 1:(NUM_DISTRICTS - 1)) {
    print(district_id)
    student_plot <- student_plot +
        geom_vline(xintercept = SCHOOLS_PER_DISTRICT * STUDENTS_PER_SCHOOL * district_id + 0.5)
}
print(student_plot)
[1] 1
[1] 2


Great! Now that we have the data and a good handle on the generation process, let’s look at what this looks like as a model. Again, there are a few levels here. There will be a shared set of district-level parameters, as well as school-level parameters that are influenced by the district each school belongs to. Each student’s score is then drawn from a distribution governed by their school’s parameters. There’s a lot here, but let’s see what it looks like in Stan!
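
For reference, this is my reading of the model that the Stan program below encodes, written in the same notation as the earlier models. Two details are inferred from the code rather than stated anywhere: the raw district offsets are constrained to be positive (hence the half-normal), and `nu_school` is given no explicit prior, so it only has an implicit flat prior on the positive reals.

```latex
\begin{align*}
\text{s} &\sim \text{Beta}(\alpha[\text{school}], \beta[\text{school}]) \\
\alpha[\text{school}] &= \mu[\text{school}] * \nu[\text{school}] \\
\beta[\text{school}] &= (1 - \mu[\text{school}]) * \nu[\text{school}] \\
\mu[\text{school}] &\sim \text{Beta}(\alpha[\text{district}], \beta[\text{district}]) \\
\alpha[\text{district}] &= \mu[\text{district}] * \nu[\text{district}] \\
\beta[\text{district}] &= (1 - \mu[\text{district}]) * \nu[\text{district}] \\
\mu[\text{district}] &\sim \text{Beta}(\bar\alpha, \bar\beta) \\
\nu[\text{district}] &= \bar\nu + \sigma_\nu * z[\text{district}] \\
z[\text{district}] &\sim \text{HalfNormal}(0, 1) \\
\bar\alpha &\sim \text{Normal}(3, 0.5) \\
\bar\beta &\sim \text{Normal}(1, 0.3) \\
\bar\nu &\sim \text{Normal}(40, 5) \\
\sigma_\nu &\sim \text{Exponential}(1)
\end{align*}
```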


# NB: this is the non-centred parameterization to avoid divergent transitions
student_school_district_model_code <- "
data {
    int<lower = 0> N_students; // number of students
    int<lower = 0> N_schools; // number of schools
    int<lower = 0> N_districts; // number of districts
    vector[N_students] s; // student scores
    array[N_students] int school_id; // the school id for each student
    array[N_schools] int district_id; // the district for each school
}
parameters {
    array[N_districts] real<lower = 0, upper = 1> mu_district;
    array[N_districts] real<lower = 0> nu_district_raw; // non-centred offset for nu_district; with the lower bound this is half-normal, so nu_district >= nu_bar
    array[N_schools] real<lower = 0, upper = 1> mu_school;
    array[N_schools] real<lower = 0> nu_school;
    real<lower = 0> alpha_bar;
    real<lower = 0> beta_bar;
    real<lower = 0> nu_bar;
    real<lower = 0> sigma_nu;
}
 transformed parameters {
    // array[N_schools] real<lower = 0> nu_school;
    array[N_schools] real<lower = 0> a_school;
    array[N_schools] real<lower = 0> b_school;
    array[N_districts] real<lower = 0> a_district;
    array[N_districts] real <lower = 0> b_district;
    array[N_districts] real <lower = 0> nu_district;
    for (i in 1:N_schools) {
        // nu_school[i] = nu_bar + sigma_nu * nu_school_raw[i];
        a_school[i] = mu_school[i] * nu_school[i];
        b_school[i] = (1 - mu_school[i]) * nu_school[i];
    }
    for (district_index in 1:N_districts) {
        nu_district[district_index] = nu_bar + sigma_nu * nu_district_raw[district_index];
        a_district[district_index] = mu_district[district_index] * nu_district[district_index];
        b_district[district_index] = (1 - mu_district[district_index]) * nu_district[district_index];
    }
}
model {
    for (student_index in 1:N_students) {
        s[student_index] ~ beta(a_school[school_id[student_index]], b_school[school_id[student_index]]);
    }
    for (school_index in 1:N_schools) {
        mu_school[school_index] ~ beta(a_district[district_id[school_index]], b_district[district_id[school_index]]);
    }
    mu_district ~ beta(alpha_bar, beta_bar);
    nu_district_raw ~ std_normal();

    // hyperpriors
    alpha_bar ~ normal(3, 0.5);
    beta_bar ~ normal(1, 0.3);
    nu_bar ~ normal(40, 5);
    sigma_nu ~ exponential(1);
}
"

student_school_district_model_file <- write_stan_file(student_school_district_model_code)
student_school_district_model <- cmdstan_model(student_school_district_model_file)

student_school_district_model_data <- list(
    N_students = nrow(student_data),
    N_schools = length(unique(student_data$school_id)),
    N_districts = length(unique(student_data$district_id)),
    s = student_data$score,
    school_id = as.integer(student_data$school_id),  # factor -> integer ids, as in the earlier model
    district_id = school_params$district
)

student_school_district_fit <- student_school_district_model$sample(
    data = student_school_district_model_data,
    chains = 4,
    iter_warmup = 1000,
    adapt_delta = 0.95
)
student_school_district_fit$summary()
Running MCMC with 4 sequential chains...

All 4 chains finished successfully.
Mean chain execution time: 0.9 seconds.
Total execution time: 3.9 seconds.
A draws_summary: 56 × 10
variable            mean         median       sd           mad          q5            q95          rhat       ess_bulk  ess_tail
<chr>               <dbl>        <dbl>        <dbl>        <dbl>        <dbl>         <dbl>        <dbl>      <dbl>     <dbl>
lp__                222.7176947  223.0595000  3.93686660   3.79693860   215.76725000  228.5353500  1.0035301  1654.261  2357.782
mu_district[1]      0.8108640    0.8118355    0.03423787   0.03307384   0.75313205    0.8667575    1.0005594  5053.485  2857.642
mu_district[2]      0.6840614    0.6850335    0.04265020   0.04337569   0.61480330    0.7528840    1.0017076  3930.215  2879.500
mu_district[3]      0.6941475    0.6959630    0.04198568   0.04179153   0.62348970    0.7613662    1.0008498  5018.537  2966.480
nu_district_raw[1]  0.8080386    0.6769430    0.60914963   0.60004825   0.06027230    1.9802930    0.9999384  3687.749  1631.689
nu_district_raw[2]  0.7951544    0.6724660    0.59759151   0.58762851   0.06592412    1.9247665    1.0015042  2791.602  1709.045
nu_district_raw[3]  0.8103433    0.6857005    0.60298377   0.59143509   0.06953199    1.9685580    1.0001191  3552.671  1907.645
mu_school[1]        0.7745501    0.7750630    0.02028668   0.01939241   0.73972350    0.8062264    1.0015966  4542.343  2547.013
mu_school[2]        0.8342911    0.8349070    0.01277099   0.01272219   0.81273770    0.8545590    1.0011919  4220.755  3036.484
mu_school[3]        0.8395642    0.8404945    0.01672362   0.01640942   0.81096990    0.8656431    1.0010039  4468.803  2866.154
mu_school[4]        0.7244529    0.7250115    0.01776163   0.01737237   0.69428420    0.7529640    1.0009774  3779.505  2623.897
mu_school[5]        0.6209195    0.6212540    0.02417814   0.02367193   0.58176280    0.6599121    1.0006665  4574.196  2910.079
mu_school[6]        0.7070876    0.7075630    0.01730462   0.01654804   0.67776680    0.7344344    0.9997740  4020.686  2711.795
mu_school[7]        0.6434111    0.6434790    0.01701471   0.01667184   0.61544455    0.6715378    1.0011385  4227.745  2632.624
mu_school[8]        0.7353718    0.7358175    0.01382481   0.01381042   0.71246380    0.7575197    1.0012051  3897.445  2490.404
mu_school[9]        0.7056546    0.7066625    0.01718360   0.01605433   0.67772540    0.7324504    1.0004752  4186.585  2553.651
nu_school[1]        19.2145683   18.5469500   5.71679088   5.46189840   11.00878000   29.4999550   0.9997007  4618.797  2857.748
nu_school[2]        43.2529643   41.9769500   13.32993696  12.86148087  24.15926000   67.6573800   1.0001604  4144.801  2702.517
nu_school[3]        25.7384937   24.9508000   7.72034944   7.77319767   14.41916500   39.5445250   1.0004349  3992.273  3020.537
nu_school[4]        30.0836856   29.1282000   9.12058981   8.83051386   16.85719500   46.4470350   1.0002684  4392.158  2939.064
nu_school[5]        18.6909772   18.1623000   5.52931243   5.42742795   10.69125500   28.4436700   1.0003793  4818.263  3010.402
nu_school[6]        34.3268246   33.3853000   10.33198302  9.93253044   19.77090500   52.8546050   1.0001168  4400.425  2784.537
nu_school[7]        38.6659344   37.6974000   11.69556233  11.55316050  21.88963000   59.5734650   1.0024960  5346.653  2871.480
nu_school[8]        55.3977254   53.8983000   17.19780834  16.83129063  30.63416500   86.0647300   1.0006023  4799.014  3115.822
nu_school[9]        36.4689579   35.2106500   11.43582314  11.03187834  20.13321000   57.4454050   1.0001742  4279.718  2810.782
alpha_bar           3.0422431    3.0382050    0.48231749   0.47995469   2.24597800    3.8459800    0.9999978  5304.740  2262.935
beta_bar            1.1143574    1.1040950    0.25961945   0.27084730   0.70305225    1.5594570    1.0024779  4861.593  2044.442
nu_bar              40.9648167   40.9428500   4.80195340   4.84209747   33.12208000   48.9125500   1.0006440  4662.214  2713.798
sigma_nu            1.0194074    0.7193850    0.99442673   0.73303228   0.06086981    2.9616380    1.0005527  3729.072  1817.774
a_school[1]         14.8995139   14.3299500   4.50320965   4.33193481   8.48290100    22.9380250   0.9996985  4538.265  2901.148
a_school[2]         36.1175329   34.9564500   11.23951209  10.78080003  20.09869000   56.4363000   1.0001898  4065.883  2650.351
a_school[3]         21.6456836   20.9735000   6.61668905   6.61232187   11.94504000   33.4753800   1.0004215  3870.308  2884.144
a_school[4]         21.8187776   21.1334500   6.70637210   6.41261565   12.10184500   33.7943250   1.0003935  4254.781  2893.809
a_school[5]         11.6109026   11.2779000   3.47769899   3.46460640   6.62352250    17.8284000   0.9997981  4838.969  3135.722
a_school[6]         24.2936537   23.5578000   7.40209612   6.96792348   13.78885000   37.6521700   1.0002818  4319.045  2782.099
a_school[7]         24.8864011   24.2969500   7.58076818   7.54791660   13.91935000   38.3683450   1.0021202  5219.617  2808.585
a_school[8]         40.7653610   39.6687500   12.76235897  12.46696101  22.32013000   63.8617600   1.0004592  4707.321  2921.420
a_school[9]         25.7575809   24.7504000   8.16618410   7.81196766   14.07162000   40.8460400   1.0002553  4164.509  2748.807
b_school[1]         4.3150550    4.1789750    1.28389551   1.27316051   2.49373000    6.6448395    1.0000253  5010.965  2645.208
b_school[2]         7.1354307    6.8631350    2.16551649   2.07649991   4.02084000    11.1983650   1.0001576  4643.735  2639.532
b_school[3]         4.0928110    3.9709350    1.18284069   1.13933362   2.36425450    6.2035020    1.0003813  4963.040  2705.370
b_school[4]         8.2649078    8.0173400    2.48574372   2.36380555   4.64036200    12.7424900   1.0000374  4801.247  2700.400
b_school[5]         7.0800751    6.8545050    2.12275121   2.09103680   4.03383650    10.8581950   1.0016265  4784.077  3056.471
b_school[6]         10.0331710   9.7383350    3.00523071   2.88286381   5.76717100    15.4092950   1.0000639  4591.238  2680.650
b_school[7]         13.7795324   13.4024500   4.18863892   4.18923456   7.75234100    21.2982450   1.0029906  5496.540  2827.717
b_school[8]         14.6323653   14.2183500   4.51502754   4.48894215   8.09328700    22.7103800   1.0009363  5022.422  2875.640
b_school[9]         10.7113761   10.3030500   3.34226299   3.20847983   5.89056500    16.8123450   1.0002922  4603.464  2897.899
a_district[1]       33.8885185   33.7491500   4.27652399   4.24512858   26.98027000   41.1086200   1.0008670  4638.184  2703.454
a_district[2]       28.5749109   28.5183500   3.79763607   3.87714726   22.42455500   34.9334950   1.0012701  4119.979  2951.821
a_district[3]       29.0074459   28.9348000   3.83058313   3.87269946   22.74620500   35.4181750   1.0012603  4829.572  2462.030
b_district[1]       7.8981206    7.8087650    1.67861747   1.71713249   5.30965150    10.8158550   1.0027108  4998.256  2774.579
b_district[2]       13.1974426   13.1079500   2.35517319   2.32227051   9.38773450    17.2396250   1.0000954  4930.257  2554.079
b_district[3]       12.7824112   12.7071500   2.30107724   2.34206322   9.12928300    16.7330850   1.0004643  4764.235  2811.945
nu_district[1]      41.7866384   41.6816000   4.91358096   4.98472359   33.71537000   50.0308250   1.0008337  4875.105  2815.289
nu_district[2]      41.7723534   41.7305500   4.91772353   4.92860718   33.74658000   49.8716450   1.0002981  4693.667  2654.541
nu_district[3]      41.7898568   41.7380500   4.91824700   4.95959352   33.68915500   49.8668050   1.0008588  4769.346  2773.605
student_school_district_draws <- as_draws_matrix(student_school_district_fit$draws())
# Extract district means
district_means_draws <- student_school_district_draws[, grep('mu_district', colnames(student_school_district_draws))]

# Calculate the mean, lower, and upper quantiles for each district
district_mean_means <- apply(district_means_draws, 2, mean)
district_mean_lower <- apply(district_means_draws, 2, function(col) quantile(col, 0.025))
district_mean_upper <- apply(district_means_draws, 2, function(col) quantile(col, 0.975))

# Calculate the mean of the shared distribution from which the district means are drawn
prior_district_mean <- mean(student_school_district_draws[, 'alpha_bar'] / (student_school_district_draws[, 'alpha_bar'] + student_school_district_draws[, 'beta_bar']))

# Create a data frame to store the results
district_means_df <- data.frame(
    district = 1:NUM_DISTRICTS,
    mean = district_mean_means,
    lower = district_mean_lower,
    upper = district_mean_upper
)
# Plot the district means along with the true means
ggplot(district_means_df, aes(district, mean)) +
    geom_point(aes(colour = "Estimated Mean")) +
    geom_errorbar(aes(ymin = lower, ymax = upper)) +
    geom_point(data = data.frame(district = 1:NUM_DISTRICTS, mean = district_means), aes(district, mean, colour = "True Mean")) +
    geom_hline(yintercept = prior_district_mean, linetype = 'dashed') +
    labs(x = "District", y = "Mean Score", colour = "Mean Type") +
    scale_y_continuous(labels = scales::percent_format(accuracy = 1)) +
    theme_bw() +
    theme(
        panel.grid.major = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
    )


# Extract school means
school_means_draws <- student_school_district_draws[, grep('mu_school', colnames(student_school_district_draws))]

# Calculate the mean, lower, and upper quantiles for each school
school_mean_means <- apply(school_means_draws, 2, mean)
school_mean_lower <- apply(school_means_draws, 2, function(col) quantile(col, 0.025))
school_mean_upper <- apply(school_means_draws, 2, function(col) quantile(col, 0.975))

# Posterior mean of each district mean, reusing district_means_draws from above (drawn as dashed segments below)
district_mean_means <- apply(district_means_draws, 2, mean)

# Create a data frame to store the results
school_means_df <- data.frame(
    school = as.factor(1:TOTAL_SCHOOLS),
    mean = school_mean_means,
    lower = school_mean_lower,
    upper = school_mean_upper
)

# Plot the school means along with the true means
plot <- ggplot(school_means_df, aes(school, mean)) +
    geom_point(aes(colour = "Estimated Mean")) +
    geom_errorbar(aes(ymin = lower, ymax = upper)) +
    geom_point(data = data.frame(school = 1:TOTAL_SCHOOLS, mean = school_params$school_mu), aes(school, mean, colour = "True Mean")) +
    labs(x = "School", y = "Mean Score", colour = "Mean Type") +
    scale_y_continuous(labels = scales::percent_format(accuracy = 1)) +
    theme_bw() +
    theme(
        panel.grid.major = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
    )

for (district in 1:NUM_DISTRICTS) {
    # dashed segment at each district's estimated mean, spanning that district's schools
    plot <- plot +
        geom_segment(x = (district - 1) * SCHOOLS_PER_DISTRICT + 0.5, xend = district * SCHOOLS_PER_DISTRICT + 0.5, y = district_mean_means[district], yend = district_mean_means[district], linetype = "dashed")
}

print(plot)

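[Figure: estimated and true school means with 95% intervals; dashed segments mark each district's estimated mean]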

Great! So that worked out pretty well - we’re getting good estimates for both our district-level and school-level values.
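One quick way to put a number on "good estimates" is to count how many of the true values land inside their 95% intervals. Here is a minimal sketch in base R; the helper `interval_coverage` and the toy numbers are made up for illustration and are not taken from the fits above.

```r
# Fraction of true values that fall inside their [lower, upper] intervals.
# `interval_coverage` is an illustrative helper, not part of the analysis above.
interval_coverage <- function(truth, lower, upper) {
    stopifnot(length(truth) == length(lower), length(truth) == length(upper))
    mean(truth >= lower & truth <= upper)
}

# Toy values standing in for true means and model intervals
toy_truth <- c(0.62, 0.71, 0.55, 0.80)
toy_lower <- c(0.58, 0.66, 0.57, 0.74)
toy_upper <- c(0.66, 0.75, 0.63, 0.83)

# Three of the four toy intervals contain the true value
interval_coverage(toy_truth, toy_lower, toy_upper)
```

With the simulated fits, something like `interval_coverage(school_params$school_mu, school_mean_lower, school_mean_upper)` should give the school-level figure, assuming the draws are stored in the same order as the true values.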

So now that we’ve taken a look at a couple of simulated examples, let’s look at some real data: provincial achievement test scores for Alberta!

Example - Alberta Diploma Exam Scores, 2016

Each year, students in Grade 12 in Alberta write diploma exams. These are standardized exams designed to assess students’ knowledge of particular courses. Not every course has a diploma exam, but for those that do, the exam counts for 30% of a student’s final grade.

Luckily for us, the results of the diploma exam are released publicly, broken up by school, school authority, year, and subject! We’re going to apply the same idea as above to some real data.

The first thing that we need to do is to load in the data. It can be found publicly at the Canada Open Government site. For this, we’ll be looking at the data for the 2015 / 2016 school year.

school_df <- read_excel('data/diploma-multiyear-sch-list-annual.xlsx')
colnames(school_df)
  1. 'Diploma Course'
  2. 'Authority Type'
  3. 'Authority Code'
  4. 'Authority Name'
  5. 'School Code'
  6. 'School Name'
  7. '2012SchStudents Writing'
  8. '2012Sch School Mark % Exc'
  9. '2012Sch School Mark % Acc'
  10. '2012Sch School Average %'
  11. '2012Sch School Standard Deviation %'
  12. '2012Sch Exam Mark % Exc'
  13. '2012Sch Exam Mark Exc Sig'
  14. '2012Sch Exam Mark % Acc'
  15. '2012Sch Exam Mark Acc Sig'
  16. '2012Sch Exam Average %'
  17. '2012Sch Exam Standard Deviation %'
  18. '2013Sch Students Writing'
  19. '2013Sch School Mark % Exc'
  20. '2013Sch School Mark % Acc'
  21. '2013Sch School Average %'
  22. '2013Sch School Standard Deviation %'
  23. '2013Sch Exam Mark % Exc'
  24. '2013Sch Exam Mark Exc Sig'
  25. '2013Sch Exam Mark % Acc'
  26. '2013Sch Exam Mark Acc Sig'
  27. '2013Sch Exam Average %'
  28. '2013Sch Exam Standard Deviation %'
  29. '2014Sch Students Writing'
  30. '2014Sch School Mark % Exc'
  31. '2014Sch School Mark % Acc'
  32. '2014Sch School Average %'
  33. '2014Sch School Standard Deviation %'
  34. '2014Sch Exam Mark % Exc'
  35. '2014Sch Exam Mark Exc Sig'
  36. '2014Sch Exam Mark % Acc'
  37. '2014Sch Exam Mark Acc Sig'
  38. '2014Sch Exam Average %'
  39. '2014Sch Exam Standard Deviation %'
  40. '2015Sch Students Writing'
  41. '2015Sch School Mark % Exc'
  42. '2015Sch School Mark % Acc'
  43. '2015Sch School Average %'
  44. '2015Sch School Standard Deviation %'
  45. '2015Sch Exam Mark % Exc'
  46. '2015Sch Exam Mark Exc Sig'
  47. '2015Sch Exam Mark % Acc'
  48. '2015Sch Exam Mark Acc Sig'
  49. '2015Sch Exam Average %'
  50. '2015Sch Exam Standard Deviation %'
  51. '2016Sch Students Writing'
  52. '2016Sch School Mark % Exc'
  53. '2016Sch School Mark % Acc'
  54. '2016Sch School Average %'
  55. '2016Sch School Standard Deviation %'
  56. '2016Sch Exam Mark % Exc'
  57. '2016Sch Exam Mark Exc Sig'
  58. '2016Sch Exam Mark % Acc'
  59. '2016Sch Exam Mark Acc Sig'
  60. '2016Sch Exam Average %'
  61. '2016Sch Exam Standard Deviation %'
unique(school_df$`Diploma Course`)
  1. 'Biology 30'
  2. 'Chemistry 30'
  3. 'English Lang Arts 30-1'
  4. 'English Lang Arts 30-2'
  5. 'Mathematics 30-1'
  6. 'Mathematics 30-2'
  7. 'Physics 30'
  8. 'Science 30'
  9. 'Social Studies 30-1'
  10. 'Social Studies 30-2'
  11. 'French Lang Arts 30-1'
  12. 'Français 30-1'
  13. '1. The 2012/2013 results do not include students who were exempted from writing the examination because of the flooding in Calgary and southern Alberta.2. The 2015/2016 results do not include students who were exempted from writing the exam because of the Fort McMurray wildfires.3. +,=,- The percentage of students meeting the standard is significantly above (+), not significantly different from (=), or significantly below (-) the previous three-year average. A difference is reported as significant when there is a 5% or smaller probability that a difference of that size could occur by chance. The fewer the number of students, the larger the difference must be from the expectation before it is considered significant. Significance is not calculated for fewer than 6 students.'

Unsurprisingly, there’s a lot here! Let’s focus in - specifically, let’s only look at the Mathematics 30-1 course in 2016. I first taught that course in 2017, so the results here are of some personal interest to me.

relevant_schools_df <- school_df[school_df$`Diploma Course` == 'Mathematics 30-1', c("Diploma Course", "Authority Type", "Authority Code", "Authority Name", "School Code", "School Name", "2016Sch Students Writing", "2016Sch Exam Average %", "2016Sch Exam Standard Deviation %")]

# remove any where no students wrote
relevant_schools_df <- relevant_schools_df[relevant_schools_df$`2016Sch Students Writing` != 'n/a', ]

# convert to numeric
relevant_schools_df$`2016Sch Students Writing` <- as.integer(relevant_schools_df$`2016Sch Students Writing`)
relevant_schools_df$`2016Sch Exam Average %` <- as.numeric(relevant_schools_df$`2016Sch Exam Average %`)
relevant_schools_df$`2016Sch Exam Standard Deviation %` <- as.numeric(relevant_schools_df$`2016Sch Exam Standard Deviation %`)

# codes should be factors
relevant_schools_df$`Authority Code` <- as.factor(relevant_schools_df$`Authority Code`)
relevant_schools_df$`School Code` <- as.factor(relevant_schools_df$`School Code`)

head(relevant_schools_df)
A tibble: 6 × 9
Diploma Course | Authority Type | Authority Code | Authority Name | School Code | School Name | 2016Sch Students Writing | 2016Sch Exam Average % | 2016Sch Exam Standard Deviation %
<chr> | <chr> | <fct> | <chr> | <fct> | <chr> | <int> | <dbl> | <dbl>
Mathematics 30-1 | Charter | 0009 | Foundations for the Future Charter Academy Charter School Society | 0012 | FFCA High School Campus | 113 | 63.9 | 20.0
Mathematics 30-1 | Private | 0015 | Webber Academy Foundation | 0021 | Webber Academy | 59 | 84.8 | 10.8
Mathematics 30-1 | Separate | 0019 | Red Deer Catholic Regional Division No. 39 | 1272 | St. Dominic High School | 12 | 63.3 | 15.5
Mathematics 30-1 | Separate | 0019 | Red Deer Catholic Regional Division No. 39 | 4471 | Ecole Secondaire Notre Dame High School | 190 | 61.8 | 20.0
Mathematics 30-1 | Separate | 0019 | Red Deer Catholic Regional Division No. 39 | 4483 | St. Gabriel Cyber School | 26 | 38.9 | 23.0
Mathematics 30-1 | Separate | 0020 | St. Thomas Aquinas Roman Catholic Separate Regional Division No. 38 | 1328 | Holy Trinity Academy | 10 | 49.0 | 22.4

Let’s graph the school results to get a feel for what the data look like:

plot_df <- data.frame(
    school = relevant_schools_df$`School Code`,
    authority = relevant_schools_df$`Authority Code`,
    mean = relevant_schools_df$`2016Sch Exam Average %`,
    lower = relevant_schools_df$`2016Sch Exam Average %` - relevant_schools_df$`2016Sch Exam Standard Deviation %`,
    upper = relevant_schools_df$`2016Sch Exam Average %` + relevant_schools_df$`2016Sch Exam Standard Deviation %`
)
# sort by mean, ascending
plot_df <- plot_df[order(plot_df$mean), ]

# should be plotted in the order of the mean
plot_df$school <- factor(plot_df$school, levels = plot_df$school)

ggplot(plot_df, aes(school, mean)) +
    geom_point() +
    geom_errorbar(aes(ymin = lower, ymax = upper)) +
    labs(x = "School", y = "Mean Score (%)") +
    theme_bw() +
    theme(
        panel.grid.major = element_blank(),
        axis.text.x = element_blank(), axis.ticks.x = element_blank()
    )

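[Figure: school exam averages in ascending order, with error bars of one standard deviation either side]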

There’s definitely an interesting shape to the data! I am a bit shocked by the number of schools with extremely low averages, but such is life I suppose. Let’s see if colouring by the school authority reveals anything.

ggplot(plot_df, aes(school, mean, colour = authority)) +
    geom_point() +
    geom_errorbar(aes(ymin = lower, ymax = upper)) +
    labs(x = "School", y = "Mean Score", colour = "Authority") +
    theme_bw() +
    theme(
        panel.grid.major = element_blank(),
        axis.text.x = element_blank(), axis.ticks.x = element_blank()
    )

[Figure: school exam averages in ascending order, coloured by authority]

This is some glorious nonsense! Let’s try creating a crude average for the different authorities and plotting that.

authorities_df <- data.frame(authority = unique(relevant_schools_df$`Authority Code`))
authority_means <- sapply(authorities_df$authority, function(authority) {
    authority_schools <- relevant_schools_df[relevant_schools_df$`Authority Code` == authority, ]
    authority_mean <- sum(authority_schools$`2016Sch Exam Average %` * authority_schools$`2016Sch Students Writing`) / sum(authority_schools$`2016Sch Students Writing`)
    authority_mean
})
authority_num_schools <- sapply(authorities_df$authority, function(authority) {
    nrow(relevant_schools_df[relevant_schools_df$`Authority Code` == authority, ])
})

authorities_df$mean <- authority_means
authorities_df$num_schools <- authority_num_schools
# plot these, ascending
authorities_df <- authorities_df[order(authorities_df$mean), ]
authorities_df$authority <- factor(authorities_df$authority, levels = authorities_df$authority)
ggplot(authorities_df, aes(authority, mean)) +
    geom_point(aes(size = num_schools)) +
    labs(x = "Authority", y = "Mean Score") +
    theme_bw() +  # Set background to white
    theme(
        panel.grid.major.x = element_blank(),
        panel.grid.minor.x = element_blank(),
        panel.grid.major.y = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
        axis.ticks.x = element_blank(),
        axis.text.x = element_blank(),
    )

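[Figure: student-weighted authority averages in ascending order, with point size showing the number of schools]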

So this looks roughly the same as the earlier graph of the per-school average, except now we’re plotting the per-authority average.
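The "crude average" here is a student-weighted mean of the school averages. The same calculation can be written a little more compactly with `split()` and `weighted.mean()`. This is just a sketch on a toy data frame; the column names `authority`, `avg`, and `n` are hypothetical stand-ins for the real columns.

```r
# Toy data: two authorities, five schools (made-up numbers)
toy_schools <- data.frame(
    authority = c("A", "A", "B", "B", "B"),
    avg       = c(60, 70, 50, 80, 65),
    n         = c(100, 50, 20, 30, 50)
)

# One student-weighted mean per authority
toy_authority_means <- sapply(
    split(toy_schools, toy_schools$authority),
    function(d) weighted.mean(d$avg, d$n)
)
toy_authority_means
```

Weighting by the number of students means a large school counts for more than a small one, which is what the `sum(average * students) / sum(students)` expression is doing.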

Great! So now let’s try to create a hierarchical model in which each school’s results depend on the school, and each school depends in turn on its authority. There is one wrinkle here: we don’t get the individual student data (unfortunately), which is different from what we’ve considered before. Instead, we’re going to have to build a likelihood from the data that we do have, namely the mean, standard deviation, and number of students for each school. We’ll also remove any schools where fewer than 30 students wrote the exam.

So here’s our new approach:
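Reading it off the Stan code that follows (so treat this as a reconstruction, not a definitive statement of the derivation), the idea seems to be this: describe the scores in school $i$, as fractions between 0 and 1, by a Beta distribution with mean $\mu_i$ and concentration $\nu_i$, and then lean on the central limit theorem for the reported school average $\bar{y}_i$ over $n_i$ students:

```latex
\begin{aligned}
\alpha_i &= \mu_i \nu_i, \qquad \beta_i = (1 - \mu_i)\,\nu_i \\
\sigma_i^2 &= \frac{\alpha_i \beta_i}{(\alpha_i + \beta_i)^2 (\alpha_i + \beta_i + 1)}
            = \frac{\mu_i (1 - \mu_i)}{\nu_i + 1} \\
\bar{y}_i &\sim \mathrm{Normal}\!\left(\mu_i,\ \sqrt{\sigma_i^2 / n_i}\right)
\end{aligned}
```

The second line is the usual Beta variance; substituting $\alpha_i + \beta_i = \nu_i$ gives the simpler form on the right.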

Great! So now let’s filter the data.

MIN_STUDENTS <- 30
filtered_schools_df <- relevant_schools_df[relevant_schools_df$`2016Sch Students Writing` >= MIN_STUDENTS, ]
print(paste('Original number of schools:', nrow(relevant_schools_df), 'Filtered number of schools:', nrow(filtered_schools_df)))
[1] "Original number of schools: 317 Filtered number of schools: 140"
# plot the schools, ascending
plot_df <- data.frame(
    school = filtered_schools_df$`School Code`,
    authority = filtered_schools_df$`Authority Code`,
    students = filtered_schools_df$`2016Sch Students Writing`,
    mean = filtered_schools_df$`2016Sch Exam Average %`,
    lower = filtered_schools_df$`2016Sch Exam Average %` - filtered_schools_df$`2016Sch Exam Standard Deviation %`,
    upper = filtered_schools_df$`2016Sch Exam Average %` + filtered_schools_df$`2016Sch Exam Standard Deviation %`
)
# sort by mean, ascending
plot_df <- plot_df[order(plot_df$mean), ]

# should be plotted in the order of the mean
plot_df$school <- factor(plot_df$school, levels = plot_df$school)

ggplot(plot_df, aes(school, mean)) +
    geom_point(aes(size = students)) +
    geom_errorbar(aes(ymin = lower, ymax = upper)) +
    labs(x = "School", y = "Mean Score") +
    theme_bw() +
    theme(
        panel.grid.major.x = element_blank(),
        panel.grid.minor.x = element_blank(),
        panel.grid.major.y = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
        axis.ticks.x = element_blank(),
        axis.text.x = element_blank(),
    )

# now handle the authorities
authorities_df <- data.frame(authority = unique(filtered_schools_df$`Authority Code`))
authority_means <- sapply(authorities_df$authority, function(authority) {
    authority_schools <- filtered_schools_df[filtered_schools_df$`Authority Code` == authority, ]
    authority_mean <- sum(authority_schools$`2016Sch Exam Average %` * authority_schools$`2016Sch Students Writing`) / sum(authority_schools$`2016Sch Students Writing`)
    authority_mean
})
authority_num_schools <- sapply(authorities_df$authority, function(authority) {
    nrow(filtered_schools_df[filtered_schools_df$`Authority Code` == authority, ])
})


authorities_df$mean <- authority_means
authorities_df$num_schools <- authority_num_schools
# plot the authorities, ascending
authorities_df <- authorities_df[order(authorities_df$mean), ]
authorities_df$authority <- factor(authorities_df$authority, levels = authorities_df$authority)
ggplot(authorities_df, aes(authority, mean)) +
    geom_point(aes(size = num_schools)) +
    labs(x = "Authority", y = "Mean Score", size = "Number of Schools") +
    theme_bw() +
    theme(
        panel.grid.major.x = element_blank(),
        panel.grid.minor.x = element_blank(),
        panel.grid.major.y = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
        axis.ticks.x = element_blank(),
        axis.text.x = element_blank(),
    )

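[Figure: exam averages for schools with at least 30 students, in ascending order, with point size showing the number of students]

[Figure: student-weighted authority averages for the filtered schools, in ascending order, with point size showing the number of schools]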

By looking at the graphs, we see that the underlying shape hasn’t changed too much by removing the schools with fewer students; we’ve just trimmed it down a bit. Now let’s build the model! As before, we’ll start with the unpooled model, then pool all of the schools together, then finally pool by school authority, which plays the role of the district here.
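The Stan models below describe each school by a Beta mean `mu` and concentration `nu`, then treat the reported school average as normal with standard deviation `sqrt(sigma_squared / n)`. A quick simulation in base R suggests that this approximation is reasonable for a school of around a hundred students. The values of `mu`, `nu`, and `n` are illustrative assumptions, not estimates from the data.

```r
set.seed(42)

# Illustrative values only
mu <- 0.65
nu <- 30
n <- 113

# Same conversion as the transformed parameters block
alpha <- mu * nu
beta <- (1 - mu) * nu
sigma_squared <- (alpha * beta) / ((alpha + beta)^2 * (alpha + beta + 1))

# Simulate many schools of n students each and record each school's average
sim_means <- replicate(5000, mean(rbeta(n, alpha, beta)))

# The two numbers should be close to each other
c(simulated = sd(sim_means), formula = sqrt(sigma_squared / n))
```

The spread of the simulated averages and the formula should agree to within simulation noise, which is the property the normal likelihood relies on.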

# convert the authorities to indices
authority_df <- data.frame(
    authority = unique(filtered_schools_df$`Authority Code`),
    index = seq_along(unique(filtered_schools_df$`Authority Code`))
)
authority_index <- authority_df$index[match(filtered_schools_df$`Authority Code`, authority_df$authority)]
alberta_schools_data <- list(
    N_schools = nrow(filtered_schools_df),
    N_authorities = length(unique(filtered_schools_df$`Authority Code`)),
    authority_index = authority_index,
    school_mean = filtered_schools_df$`2016Sch Exam Average %` / 100,
    school_s_squared = (filtered_schools_df$`2016Sch Exam Standard Deviation %` / 100)^2,
    school_n = filtered_schools_df$`2016Sch Students Writing`
)

run_schools_model <- function(model_code_string) {
    model_file <- write_stan_file(model_code_string)
    model <- cmdstan_model(model_file)
    model_fit <- model$sample(
        data = alberta_schools_data,
        chains = 4,
        iter_warmup = 1000,
        adapt_delta = 0.99
    )
    model_fit
}

alberta_schools_unpooled_code <- "
data {
    int<lower = 0> N_schools;
    array[N_schools] real<lower = 0, upper = 1> school_mean;
    array[N_schools] int<lower = 0> school_n;
}
parameters {
    array[N_schools] real<lower = 0, upper = 1> school_mu;
    array[N_schools] real<lower = 0> school_nu;
}
transformed parameters {
    array[N_schools] real<lower = 0> school_alpha;
    array[N_schools] real<lower = 0> school_beta;
    array[N_schools] real<lower = 0> school_sigma_squared;
    for (i in 1:N_schools) {
        school_alpha[i] = school_mu[i] * school_nu[i];
        school_beta[i] = (1 - school_mu[i]) * school_nu[i];
        school_sigma_squared[i] = (school_alpha[i] * school_beta[i]) / (( school_alpha[i] + school_beta[i])^2 * (school_alpha[i] + school_beta[i] + 1));
    }
}
model {
    for (i in 1:N_schools) {
        school_mean[i] ~ normal(school_mu[i], sqrt(school_sigma_squared[i] / school_n[i]));
    }
    school_mu ~ beta(3, 1);
    school_nu ~ normal(30, 4);
}
"

alberta_schools_unpooled_fit <- run_schools_model(alberta_schools_unpooled_code)
alberta_schools_unpooled_fit$summary()
Running MCMC with 4 sequential chains...

All 4 chains finished successfully.
Mean chain execution time: 3.5 seconds.
Total execution time: 14.4 seconds.
A draws_summary: 701 × 10
variable  mean  median  sd  mad  q5  q95  rhat  ess_bulk  ess_tail
<chr>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>
lp__  649.5773287  649.7755000  12.201033065  12.312993000  629.2904000  669.0576000  1.0029001  1430.923  2053.862
school_mu[1]  0.6392039  0.6393735  0.008405132  0.008324058  0.6248339  0.6527770  1.0023608  9861.831  3085.230
school_mu[2]  0.8478341  0.8478955  0.008545827  0.008547189  0.8338857  0.8613602  1.0026531  10923.645  2489.053
school_mu[3]  0.6181320  0.6181990  0.006292014  0.006273622  0.6079128  0.6282794  1.0005929  9275.417  2710.679
school_mu[4]  0.6221342  0.6222265  0.008266059  0.008361123  0.6085825  0.6353736  1.0003571  8557.712  2885.691
school_mu[5]  0.6091299  0.6094770  0.010714624  0.010414524  0.5915248  0.6270132  1.0001241  9395.256  2746.722
school_mu[6]  0.5643965  0.5645085  0.012238623  0.012046125  0.5443312  0.5843314  1.0013317  10665.311  2956.641
school_mu[7]  0.7579158  0.7584285  0.013581550  0.013060965  0.7342792  0.7798935  1.0003750  11019.426  2544.310
school_mu[8]  0.6613427  0.6616730  0.015752203  0.015415333  0.6350851  0.6866007  1.0003204  9957.812  3062.898
school_mu[9]  0.5039022  0.5035950  0.015266937  0.015177376  0.4785290  0.5287871  1.0001035  9218.511  2836.723
school_mu[10]  0.5495140  0.5494805  0.010088190  0.010011998  0.5330208  0.5663486  1.0020066  9635.417  2255.971
school_mu[11]  0.6074124  0.6077090  0.013410985  0.013212931  0.5844894  0.6287470  1.0013039  10233.535  2432.842
school_mu[12]  0.6794369  0.6794635  0.014992119  0.014513913  0.6545396  0.7037216  1.0017366  9346.925  2523.928
school_mu[13]  0.7769397  0.7770205  0.008391946  0.008506418  0.7629991  0.7904530  1.0006798  10609.349  3094.103
school_mu[14]  0.5941705  0.5941310  0.009878963  0.009741423  0.5780519  0.6106447  0.9999009  9063.377  2557.421
school_mu[15]  0.5822358  0.5823065  0.010688901  0.010906747  0.5647210  0.5995771  0.9998798  8514.371  2849.397
school_mu[16]  0.5612107  0.5611500  0.009223588  0.008926735  0.5459749  0.5764008  1.0015537  9275.307  2389.068
school_mu[17]  0.6562535  0.6564125  0.009768611  0.009671000  0.6395904  0.6720952  1.0010295  8951.184  2822.084
school_mu[18]  0.7020372  0.7021400  0.005097796  0.005126831  0.6936178  0.7103156  1.0021930  8503.081  3030.269
school_mu[19]  0.5692039  0.5691785  0.007228040  0.007006026  0.5572946  0.5811564  1.0009770  9479.563  2602.059
school_mu[20]  0.6291591  0.6291845  0.007741292  0.007446359  0.6163890  0.6421733  1.0009216  9699.631  2890.812
school_mu[21]  0.5291810  0.5291255  0.009326573  0.009361878  0.5138633  0.5444111  1.0025134  9776.379  2905.082
school_mu[22]  0.6482721  0.6483485  0.010462342  0.010331498  0.6312333  0.6653092  1.0003151  8538.696  2790.360
school_mu[23]  0.5894198  0.5894010  0.010529789  0.010301846  0.5720068  0.6064943  1.0013085  9797.334  2746.159
school_mu[24]  0.5411187  0.5410025  0.007940582  0.007633166  0.5281375  0.5540133  1.0017240  11054.748  2125.300
school_mu[25]  0.5883560  0.5882275  0.010369037  0.010146173  0.5711747  0.6053831  1.0002894  9185.769  2727.709
school_mu[26]  0.5070336  0.5072020  0.015308408  0.014958693  0.4821989  0.5320383  1.0011184  9601.298  2389.797
school_mu[27]  0.6153622  0.6153640  0.015527387  0.015326377  0.5900032  0.6408760  1.0014148  9793.679  2783.556
school_mu[28]  0.6994129  0.6996925  0.014865654  0.014537634  0.6749133  0.7232319  1.0053603  9952.072  2730.345
school_mu[29]  0.6951703  0.6954805  0.014038333  0.014172915  0.6723685  0.7179155  1.0009009  9118.153  2701.306
⋮
school_sigma_squared[111]  0.007893264  0.007768190  0.0010697738  0.0010081383  0.006439076  0.009800432  1.0004525  7428.400  2527.758
school_sigma_squared[112]  0.007969189  0.007832225  0.0010926129  0.0010291023  0.006428079  0.009930932  1.0020825  10505.973  2602.756
school_sigma_squared[113]  0.007618596  0.007507175  0.0010481282  0.0010130680  0.006157625  0.009520502  1.0014516  8158.273  2472.058
school_sigma_squared[114]  0.007964497  0.007842500  0.0010674834  0.0010013036  0.006464247  0.009853104  1.0000604  8319.391  2477.557
school_sigma_squared[115]  0.007830389  0.007684080  0.0010859503  0.0010140984  0.006311180  0.009815478  1.0037708  8275.381  2041.495
school_sigma_squared[116]  0.007643599  0.007525235  0.0010180235  0.0009446608  0.006205853  0.009519544  1.0012818  9034.527  3076.640
school_sigma_squared[117]  0.007079024  0.006956720  0.0009694878  0.0008953644  0.005748352  0.008871571  1.0016213  9660.773  2506.069
school_sigma_squared[118]  0.007978155  0.007838450  0.0010765993  0.0010081087  0.006465149  0.009911454  1.0020249  8893.435  2733.333
school_sigma_squared[119]  0.007435998  0.007330955  0.0009979564  0.0009291677  0.005993480  0.009225692  1.0004015  8705.208  2244.065
school_sigma_squared[120]  0.007003157  0.006878335  0.0009683727  0.0008945860  0.005693553  0.008771605  1.0006612  10321.591  2666.727
school_sigma_squared[121]  0.007845839  0.007733620  0.0010310463  0.0009563511  0.006429796  0.009630830  1.0029224  9922.761  2363.052
school_sigma_squared[122]  0.007879970  0.007717540  0.0011086189  0.0009984718  0.006355423  0.009884213  1.0027917  9441.005  2877.827
school_sigma_squared[123]  0.007090519  0.006984445  0.0009657778  0.0009240230  0.005747343  0.008932001  1.0010618  10143.593  2668.185
school_sigma_squared[124]  0.008151195  0.008004090  0.0011242317  0.0010409186  0.006603203  0.010195930  0.9999701  9655.258  2639.273
school_sigma_squared[125]  0.008142819  0.008005325  0.0010949668  0.0010205032  0.006658687  0.010109410  1.0020682  10818.940  2288.183
school_sigma_squared[126]  0.007719722  0.007619380  0.0010202034  0.0009789830  0.006292084  0.009589083  1.0026693  7682.836  2778.122
school_sigma_squared[127]  0.008047316  0.007917465  0.0010627631  0.0010092132  0.006559594  0.009984389  1.0030703  10158.718  2560.377
school_sigma_squared[128]  0.007639193  0.007519170  0.0010104833  0.0009244826  0.006201755  0.009444058  1.0009677  9342.876  2821.260
school_sigma_squared[129]  0.008200299  0.008042745  0.0011332436  0.0010339059  0.006623680  0.010261805  0.9997888  9525.584  2875.477
school_sigma_squared[130]  0.007805796  0.007658680  0.0010914395  0.0009610139  0.006333933  0.009777402  1.0011799  6632.328  2140.112
school_sigma_squared[131]  0.007209483  0.007093005  0.0010073930  0.0009162839  0.005841421  0.009100803  1.0009968  8307.831  2324.168
school_sigma_squared[132]  0.007963293  0.007820475  0.0010980682  0.0009929862  0.006425254  0.009945242  1.0013978  9306.196  2473.790
school_sigma_squared[133]  0.008199732  0.008062010  0.0011416851  0.0010541953  0.006616617  0.010259080  1.0001760  8278.864  2319.211
school_sigma_squared[134]  0.007685228  0.007556640  0.0010438796  0.0009747502  0.006244291  0.009554043  1.0022049  7625.178  2862.585
school_sigma_squared[135]  0.007994192  0.007853765  0.0010924612  0.0010112296  0.006505752  0.010028570  1.0015995  7389.714  2102.009
school_sigma_squared[136]  0.008093561  0.007955880  0.0011403109  0.0010455888  0.006541122  0.010090635  1.0010682  9287.799  2156.026
school_sigma_squared[137]  0.005750439  0.005664375  0.0007842457  0.0007583647  0.004663566  0.007139022  0.9996415  9412.067  2355.305
school_sigma_squared[138]  0.006652218  0.006531960  0.0008836756  0.0008276392  0.005436707  0.008254424  1.0006992  8224.799  2692.735
school_sigma_squared[139]  0.005245804  0.005143150  0.0007534517  0.0007090460  0.004210688  0.006630924  1.0002898  9417.638  2560.559
school_sigma_squared[140]  0.007339370  0.007229320  0.0009939939  0.0009337711  0.005963415  0.009117067  0.9999096  9246.506  2855.895
generate_school_plot <- function(fit) {
    draws <- as_draws_matrix(fit)
    # posterior draws of the per-school means, one column per school
    mu_draws <- draws[, grep('school_mu', colnames(draws))]

    plot_df <- data.frame(
        school = filtered_schools_df$`School Code`,
        students = filtered_schools_df$`2016Sch Students Writing`,
        observed_mean = filtered_schools_df$`2016Sch Exam Average %` / 100,
        observed_lower = filtered_schools_df$`2016Sch Exam Average %` / 100 - filtered_schools_df$`2016Sch Exam Standard Deviation %` / 100,
        observed_upper = filtered_schools_df$`2016Sch Exam Average %` / 100 + filtered_schools_df$`2016Sch Exam Standard Deviation %` / 100,
        model_mean = colMeans(mu_draws),
        model_lower = apply(mu_draws, 2, function(col) quantile(col, 0.025)),
        model_upper = apply(mu_draws, 2, function(col) quantile(col, 0.975))
    )
    plot_df <- plot_df[order(plot_df$observed_mean), ]
    plot_df$school <- factor(plot_df$school, levels = plot_df$school)
    # plot the empirical data against the estimated means
    p <- ggplot(plot_df) +
        geom_point(aes(school, observed_mean, size = students, colour = "Observed")) +
        geom_point(aes(school, model_mean, colour = 'Model')) +
        geom_errorbar(aes(x = school, y = model_mean, ymin = model_lower, ymax = model_upper, colour = "Model")) +
        labs(x = "School", y = "Mean Score", size = "Number of Students", colour = "") +
        scale_y_continuous(labels = scales::percent_format(accuracy = 1)) +
        theme_bw() +  # Set background to white
        theme(
            panel.grid.major.x = element_blank(),
            panel.grid.minor.x = element_blank(),
            panel.grid.major.y = element_line(size = 0.5, linetype = 'dashed', colour = "grey"),
            axis.ticks.x = element_blank(),
            axis.text.x = element_blank(),
        )
    p
}

unpooled_school_means_plot <- generate_school_plot(alberta_schools_unpooled_fit)
print(unpooled_school_means_plot)

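[Figure: observed school means and unpooled model estimates with 95% intervals, schools in ascending order of observed mean]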

That looks pretty good! Now let’s pool all of the schools together using a shared prior.
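The pooled model below builds `school_nu` from a standardised `school_nu_raw` (a non-centred parameterisation). One thing worth checking is what the `lower = 0` bound on `school_nu_raw` does. Combined with `std_normal()`, the raw values are half-normal, so every `school_nu` sits at or above `school_prior_nu_mean`. A quick base-R simulation shows the difference from an unconstrained version; `nu_mean` and `nu_sd` are illustrative stand-ins for the hyperparameters, not fitted values.

```r
set.seed(1)

nu_mean <- 30   # illustrative stand-in for school_prior_nu_mean
nu_sd <- 1      # illustrative stand-in for school_prior_nu_sd

z <- rnorm(10000)

# Unconstrained raw parameter: nu ~ normal(nu_mean, nu_sd)
nu_unconstrained <- nu_mean + nu_sd * z

# Raw parameter constrained to be non-negative: half-normal
nu_half <- nu_mean + nu_sd * abs(z)

range(nu_unconstrained)   # spreads on both sides of nu_mean
min(nu_half) >= nu_mean   # TRUE: never below nu_mean
mean(nu_half) - nu_mean   # close to nu_sd * sqrt(2 / pi), roughly 0.8
```

Whether that one-sided behaviour is wanted is a modelling choice; dropping the bound on the raw parameter would give the symmetric version, at the cost of needing some other way to keep `school_nu` positive.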

alberta_schools_pooled_code <- "
data {
    int<lower = 0> N_schools;
    array[N_schools] real<lower = 0, upper = 1> school_mean;
    array[N_schools] int<lower = 0> school_n;
}
parameters {
    array[N_schools] real<lower = 0, upper = 1> school_mu;
    array[N_schools] real<lower = 0> school_nu_raw; // non-centred; the lower = 0 bound makes this half-normal, so school_nu >= school_prior_nu_mean

    real <lower = 0> school_prior_alpha;
    real <lower = 0> school_prior_beta;
    real <lower = 0> school_prior_nu_mean;
    real <lower = 0> school_prior_nu_sd;
}
transformed parameters {
    array[N_schools] real<lower = 0> school_alpha;
    array[N_schools] real<lower = 0> school_beta;
    array[N_schools] real<lower = 0> school_sigma_squared;
    array[N_schools] real<lower = 0> school_nu;
    for (i in 1:N_schools) {
        school_nu[i] = school_prior_nu_mean + school_prior_nu_sd * school_nu_raw[i];
        school_alpha[i] = school_mu[i] * school_nu[i];
        school_beta[i] = (1 - school_mu[i]) * school_nu[i];
        school_sigma_squared[i] = (school_alpha[i] * school_beta[i]) / (( school_alpha[i] + school_beta[i])^2 * (school_alpha[i] + school_beta[i] + 1));
    }
}
model {
    for (i in 1:N_schools) {
        school_mean[i] ~ normal(school_mu[i], sqrt(school_sigma_squared[i] / school_n[i]));
    }
    school_mu ~ beta(school_prior_alpha, school_prior_beta);
    school_nu_raw ~ std_normal();

    school_prior_alpha ~ normal(3, 0.5);
    school_prior_beta ~ normal(1, 0.3);
    school_prior_nu_mean ~ normal(30, 4);
    school_prior_nu_sd ~ exponential(1);
}
"

alberta_schools_pooled_fit <- run_schools_model(alberta_schools_pooled_code)
alberta_schools_pooled_fit$summary()
Running MCMC with 4 sequential chains...

All 4 chains finished successfully.
Mean chain execution time: 7.3 seconds.
Total execution time: 29.3 seconds.
A draws_summary: 845 × 10
variable  mean  median  sd  mad  q5  q95  rhat  ess_bulk  ess_tail
<chr>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>
lp__  303.3768817  303.9490000  16.010609953  15.280416900  276.3478500  328.5250000  1.0015113  893.8268  1506.195
school_mu[1]  0.6390792  0.6389565  0.008087173  0.008065344  0.6256816  0.6524580  1.0024593  4588.5802  3114.740
school_mu[2]  0.8472796  0.8473785  0.008252892  0.008267719  0.8334421  0.8606413  1.0009509  4664.1656  2717.473
school_mu[3]  0.6180220  0.6181105  0.006271633  0.006166133  0.6076868  0.6283191  1.0017390  4661.0568  3093.812
school_mu[4]  0.6220444  0.6221355  0.008126488  0.008079429  0.6084968  0.6352612  1.0018828  4967.9453  3136.635
school_mu[5]  0.6090356  0.6091480  0.010636975  0.010635431  0.5912527  0.6262207  1.0002508  4984.4110  2805.612
school_mu[6]  0.5643344  0.5645605  0.011741694  0.011797048  0.5449506  0.5838678  1.0006932  4318.7214  2945.927
school_mu[7]  0.7572024  0.7575620  0.013173178  0.013160299  0.7349491  0.7783756  0.9999079  4704.4222  2625.790
school_mu[8]  0.6605515  0.6607250  0.015465255  0.015314517  0.6343163  0.6858114  1.0002815  4391.0870  2871.989
school_mu[9]  0.5038833  0.5039180  0.015056639  0.015186272  0.4792564  0.5281737  1.0010799  5395.6317  2740.259
school_mu[10]  0.5492591  0.5494235  0.010111278  0.010022376  0.5324129  0.5659195  1.0012927  4671.2267  2758.982
school_mu[11]  0.6073463  0.6074625  0.013313283  0.013341917  0.5854852  0.6290895  1.0020000  4909.1898  2668.409
school_mu[12]  0.6782184  0.6787230  0.014706708  0.015006136  0.6538240  0.7014432  1.0012839  4657.2681  2982.858
school_mu[13]  0.7765780  0.7767805  0.008369508  0.008294406  0.7624945  0.7903953  1.0022197  4809.6645  2783.120
school_mu[14]  0.5941792  0.5943005  0.009515127  0.009766628  0.5786709  0.6097708  1.0023599  5141.7190  2850.491
school_mu[15]  0.5822526  0.5823935  0.010862116  0.010788880  0.5644139  0.6000414  0.9997710  5552.0371  3009.097
school_mu[16]  0.5612144  0.5611465  0.009131184  0.009202498  0.5463539  0.5758237  1.0022781  4576.8798  2877.150
school_mu[17]  0.6560447  0.6561910  0.010477979  0.010462708  0.6388660  0.6728194  1.0020644  4727.2855  2872.770
school_mu[18]  0.7019719  0.7019980  0.004950784  0.004878495  0.6937125  0.7101738  1.0000035  4318.3077  3191.356
school_mu[19]  0.5690055  0.5690850  0.007141884  0.007216556  0.5573088  0.5803989  0.9995899  4456.1963  3119.883
school_mu[20]  0.6287745  0.6287395  0.007553180  0.007547175  0.6164030  0.6412437  0.9994805  4260.8254  3093.037
school_mu[21]  0.5293484  0.5295380  0.009045090  0.009111318  0.5146197  0.5441755  1.0006101  5183.6144  3166.340
school_mu[22]  0.6478409  0.6479130  0.010385516  0.010494584  0.6308290  0.6643548  1.0005523  4699.0893  3190.544
school_mu[23]  0.5893679  0.5892725  0.010550171  0.010570197  0.5720445  0.6067484  1.0017835  5601.2755  3042.860
school_mu[24]  0.5410172  0.5408530  0.008148151  0.008352227  0.5278048  0.5544735  0.9993166  4670.5756  2990.563
school_mu[25]  0.5877308  0.5877460  0.009933401  0.009825932  0.5711710  0.6040176  1.0008139  4943.8628  2533.125
school_mu[26]  0.5065683  0.5063280  0.015377681  0.015707406  0.4820062  0.5320932  1.0022437  5312.7353  2943.792
school_mu[27]  0.6150520  0.6149875  0.015574998  0.015925348  0.5894119  0.6402502  1.0002420  4431.5376  2993.436
school_mu[28]  0.6987703  0.6987035  0.014475063  0.014868254  0.6752261  0.7222509  1.0011157  4728.1448  2926.576
school_mu[29]  0.6943627  0.6943935  0.013998895  0.013949783  0.6707471  0.7169856  1.0006656  4834.9427  2654.505
⋮
school_nu[111]  30.78090  30.82385  4.101406  4.041642  24.03228  37.58628  1.0005771  1472.580  1875.910
school_nu[112]  30.80212  30.71015  4.155457  4.101242  23.99994  37.52887  1.0003763  1475.403  1967.981
school_nu[113]  30.78881  30.76595  4.128469  4.117773  24.13262  37.50174  0.9999277  1471.802  1808.811
school_nu[114]  30.80189  30.78960  4.130428  4.069737  24.06187  37.53622  1.0002736  1472.135  1782.622
school_nu[115]  30.78738  30.75695  4.115737  4.065586  24.05419  37.54105  1.0001634  1463.730  1951.738
school_nu[116]  30.76124  30.75665  4.130047  4.096424  24.02190  37.54741  1.0001140  1467.214  1807.532
school_nu[117]  30.78732  30.77025  4.105511  4.035860  24.06202  37.46641  0.9998391  1457.559  1697.574
school_nu[118]  30.77963  30.77405  4.118421  4.077076  24.03529  37.62413  1.0000245  1453.870  1776.225
school_nu[119]  30.78651  30.78055  4.138759  4.055356  24.06464  37.46353  1.0001119  1451.447  1861.304
school_nu[120]  30.80518  30.83210  4.136719  4.043866  24.02471  37.53602  1.0003245  1428.196  1799.062
school_nu[121]  30.79084  30.74740  4.132359  4.059729  24.09940  37.46145  1.0002736  1447.110  1900.183
school_nu[122]  30.76768  30.79925  4.095244  4.018810  24.09209  37.46993  1.0000350  1452.188  1818.884
school_nu[123]  30.79588  30.80065  4.115233  4.056764  24.06658  37.54259  0.9998784  1483.732  1780.509
school_nu[124]  30.78983  30.80970  4.140070  4.113325  24.04521  37.54612  1.0003622  1448.044  1807.375
school_nu[125]  30.79457  30.79990  4.119491  4.058766  24.09196  37.57005  1.0003071  1460.175  1883.421
school_nu[126]  30.78053  30.81370  4.110078  4.089307  24.12297  37.50183  1.0004938  1418.450  1938.605
school_nu[127]  30.79112  30.80010  4.137164  4.091902  24.13109  37.51491  1.0001893  1456.566  1874.482
school_nu[128]  30.78242  30.79370  4.099799  4.035415  24.08486  37.60762  1.0001670  1452.210  1833.954
school_nu[129]  30.77612  30.77305  4.095918  4.020292  24.08069  37.46192  1.0002709  1454.949  1808.966
school_nu[130]  30.79206  30.78095  4.135395  4.103392  24.04835  37.42886  0.9998866  1433.656  1889.972
school_nu[131]  30.79639  30.81580  4.106577  4.074111  24.04705  37.46138  0.9998806  1474.317  1819.319
school_nu[132]  30.78464  30.79790  4.111246  4.040826  24.02001  37.51579  0.9998444  1496.896  1940.522
school_nu[133]  30.79749  30.78565  4.114845  4.090864  24.09493  37.53969  1.0000261  1491.218  1883.304
school_nu[134]  30.78970  30.75950  4.146862  4.154319  24.02222  37.48786  1.0003382  1421.417  1925.859
school_nu[135]  30.77191  30.77215  4.163201  4.090197  24.00241  37.47392  1.0007963  1412.177  1826.859
school_nu[136]  30.80711  30.77620  4.134148  4.106061  24.03661  37.54306  0.9999799  1442.295  1972.031
school_nu[137]  30.81039  30.78280  4.113771  4.062472  24.05631  37.48475  0.9999788  1464.940  1851.897
school_nu[138]  30.77849  30.77695  4.106998  4.029484  24.10759  37.62079  1.0007396  1448.066  1733.536
school_nu[139]  30.75703  30.80815  4.117619  4.071442  24.02906  37.45361  0.9999884  1477.264  1725.516
school_nu[140]  30.77726  30.80615  4.108771  4.063658  24.04887  37.55635  1.0000692  1453.901  1852.164
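# Posterior draws as a matrix, with one column per parameter
draws <- as_draws_matrix(alberta_schools_pooled_fit)
# Posterior mean of the shared prior's mean, alpha / (alpha + beta)
prior_mean <- mean(draws[, 'school_prior_alpha'] / (draws[, 'school_prior_alpha'] + draws[, 'school_prior_beta']))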
pooled_school_means_plot <- generate_school_plot(alberta_schools_pooled_fit)
pooled_school_means_plot <- pooled_school_means_plot +
    geom_hline(yintercept = prior_mean, linetype = 'dashed')
print(pooled_school_means_plot)

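[Figure: school means plot for the pooled model, with the prior mean drawn as a dashed horizontal line]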

Great! So again we’re doing a good job of capturing the data. These estimates are very similar to the unpooled ones, probably because not much of the variation is explained by the differences between schools. Now let’s pool the schools by district (the school authority in the code below)!


alberta_schools_authority_pooled_code <- "
data {
    int<lower = 0> N_schools;
    int<lower = 0> N_authorities;
    array[N_schools] real<lower = 0, upper = 1> school_mean;
    array[N_schools] int<lower = 0> school_n;
    array[N_schools] int<lower = 1, upper = N_authorities> authority_index;
}
parameters {
    array[N_schools] real<lower = 0, upper = 1> school_mu;
    array[N_schools] real<lower = 0> school_nu_raw;

    real<lower = 0> authority_prior_alpha_mean;
    real<lower = 0> authority_prior_alpha_sd;
    array[N_authorities] real<lower = 0> school_prior_alpha_raw;

    real<lower = 0> authority_prior_beta_mean;
    real<lower = 0> authority_prior_beta_sd;
    array[N_authorities] real<lower = 0> school_prior_beta_raw;

    real <lower = 0> school_prior_nu_mean;
    real <lower = 0> school_prior_nu_sd;
}
transformed parameters {
    array[N_schools] real<lower = 0> school_alpha;
    array[N_schools] real<lower = 0> school_beta;
    array[N_schools] real<lower = 0> school_sigma_squared;
    array[N_schools] real<lower = 0> school_nu;
    array[N_authorities] real<lower = 0> school_prior_alpha;
    array[N_authorities] real<lower = 0> school_prior_beta;

    for (i in 1:N_authorities) {
        school_prior_alpha[i] = authority_prior_alpha_mean + authority_prior_alpha_sd * school_prior_alpha_raw[i];
        school_prior_beta[i] = authority_prior_beta_mean + authority_prior_beta_sd * school_prior_beta_raw[i];
    }
    for (i in 1:N_schools) {
        school_nu[i] = school_prior_nu_mean + school_prior_nu_sd * school_nu_raw[i];
        school_alpha[i] = school_mu[i] * school_nu[i];
        school_beta[i] = (1 - school_mu[i]) * school_nu[i];
        school_sigma_squared[i] = (school_alpha[i] * school_beta[i]) / (( school_alpha[i] + school_beta[i])^2 * (school_alpha[i] + school_beta[i] + 1));
    }
}
model {
    for (i in 1:N_schools) {
        school_mu[i] ~ beta(school_prior_alpha[authority_index[i]], school_prior_beta[authority_index[i]]);
        school_mean[i] ~ normal(school_mu[i], sqrt(school_sigma_squared[i] / school_n[i]));
    }
    school_nu_raw ~ std_normal(); // non-centred: school_nu = school_prior_nu_mean + school_prior_nu_sd * school_nu_raw; school_nu_raw is constrained to be positive, so this is a half-normal

    // TODO create a shared prior for the hyperparameters alpha and beta
    authority_prior_alpha_mean ~ normal(3, 0.5);
    authority_prior_alpha_sd ~ exponential(1);
    authority_prior_beta_mean ~ normal(1, 0.3);
    authority_prior_beta_sd ~ exponential(1);

    school_prior_alpha_raw ~ std_normal();
    school_prior_beta_raw ~ std_normal();

    school_prior_nu_mean ~ normal(30, 4);
    school_prior_nu_sd ~ exponential(1);
}
"
alberta_schools_authority_fit <- run_schools_model(alberta_schools_authority_pooled_code)
alberta_schools_authority_fit$summary()
Running MCMC with 4 sequential chains...

All 4 chains finished successfully.
Mean chain execution time: 14.9 seconds.
Total execution time: 60.0 seconds.
A draws_summary: 1067 × 10
variable  mean  median  sd  mad  q5  q95  rhat  ess_bulk  ess_tail
<chr>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>  <dbl>
lp__  242.7581233  243.4775000  20.319465354  19.953572100  208.5067000  275.4238000  1.0037834  866.746  1539.503
school_mu[1]  0.6387844  0.6387405  0.007867415  0.007838506  0.6258293  0.6516180  1.0017875  7730.782  3023.918
school_mu[2]  0.8468010  0.8471110  0.008234859  0.008496039  0.8331503  0.8597739  1.0019272  8346.506  2944.039
school_mu[3]  0.6178737  0.6180900  0.006078236  0.006126845  0.6077802  0.6278174  1.0005921  7897.363  3082.192
school_mu[4]  0.6220775  0.6221365  0.008252685  0.008003816  0.6084058  0.6355993  1.0035225  9066.323  2903.928
school_mu[5]  0.6089325  0.6090580  0.010396388  0.010213631  0.5917600  0.6258283  1.0004491  8822.524  2701.826
school_mu[6]  0.5651614  0.5651690  0.012056050  0.012080966  0.5459932  0.5843686  1.0016742  8920.377  2358.078
school_mu[7]  0.7564602  0.7566715  0.013304278  0.013247031  0.7337154  0.7776436  1.0003112  8465.891  2869.356
school_mu[8]  0.6607885  0.6612545  0.015222589  0.015111400  0.6355575  0.6852125  1.0003871  7745.071  3102.663
school_mu[9]  0.5033227  0.5035615  0.014959102  0.014415320  0.4786364  0.5275542  1.0018842  7362.804  2435.936
school_mu[10]  0.5489648  0.5489295  0.009962331  0.010085386  0.5327384  0.5654756  1.0009309  7858.609  2480.194
school_mu[11]  0.6069848  0.6070440  0.013525705  0.012806699  0.5846459  0.6296254  1.0003953  7109.660  2533.558
school_mu[12]  0.6785647  0.6788500  0.014326434  0.014209238  0.6543459  0.7016713  1.0013475  8032.309  2550.878
school_mu[13]  0.7764559  0.7766735  0.008665420  0.008676916  0.7621878  0.7902724  1.0022818  8470.686  2569.698
school_mu[14]  0.5940918  0.5941210  0.009433348  0.009140970  0.5780491  0.6095894  0.9994629  7222.579  2868.847
school_mu[15]  0.5822622  0.5822365  0.010469228  0.010771089  0.5654695  0.5993273  1.0011668  8890.172  2870.622
school_mu[16]  0.5611338  0.5611185  0.008928162  0.008979367  0.5461858  0.5759718  0.9999743  8275.079  2942.549
school_mu[17]  0.6552331  0.6551580  0.010271325  0.010408593  0.6387280  0.6719518  1.0008111  8512.814  2635.261
school_mu[18]  0.7017329  0.7018120  0.005121847  0.005091248  0.6931912  0.7099870  1.0011110  7544.982  2962.647
school_mu[19]  0.5692686  0.5692310  0.007138546  0.007034937  0.5576670  0.5809718  1.0041092  6968.875  2653.758
school_mu[20]  0.6290081  0.6291435  0.007791504  0.007737689  0.6161220  0.6418232  1.0018235  7827.995  2969.588
school_mu[21]  0.5297266  0.5297355  0.009124199  0.009122438  0.5144827  0.5445661  1.0005499  7307.825  2901.814
school_mu[22]  0.6477019  0.6477230  0.010225237  0.010181756  0.6304554  0.6646334  1.0012374  7897.992  2363.115
school_mu[23]  0.5891082  0.5892005  0.010622697  0.010529425  0.5712892  0.6062800  1.0007224  7841.715  2924.020
school_mu[24]  0.5413668  0.5413110  0.007999754  0.007971199  0.5283126  0.5546806  1.0014882  8305.301  2780.133
school_mu[25]  0.5880985  0.5882500  0.010154796  0.009871151  0.5715252  0.6044875  1.0009068  7350.320  2446.588
school_mu[26]  0.5064102  0.5062810  0.015608235  0.015106953  0.4805963  0.5324869  1.0005418  6929.813  2712.310
school_mu[27]  0.6151159  0.6152560  0.015493198  0.015774864  0.5897355  0.6402394  1.0003265  7670.134  2554.444
school_mu[28]  0.6982921  0.6985970  0.014558349  0.015005395  0.6741918  0.7219588  1.0000773  8525.254  3075.386
school_mu[29]  0.6944085  0.6943485  0.013937068  0.013428649  0.6708920  0.7171905  1.0034761  9790.319  2710.654
⋮
school_prior_beta[26]  6.444429  5.945595  2.176390  1.812382  3.942815  10.682965  1.000358  2368.802  2829.265
school_prior_beta[27]  6.228340  5.690555  2.115253  1.673648  3.876286  10.295395  1.000566  2573.989  2812.977
school_prior_beta[28]  6.564124  6.159870  2.046347  1.858002  4.032383  10.587910  1.002567  1435.181  1861.186
school_prior_beta[29]  6.679653  6.131620  2.304941  1.909085  4.031681  11.070340  1.001554  2155.847  2806.314
school_prior_beta[30]  6.215432  5.779615  1.949213  1.717718  3.897775  9.921474  1.000991  2442.833  2637.715
school_prior_beta[31]  6.405481  5.840280  2.213723  1.831693  3.904279  10.745250  1.002140  2487.839  2745.415
school_prior_beta[32]  6.358579  5.781295  2.259589  1.772270  3.870992  10.748180  1.000821  2536.184  2771.503
school_prior_beta[33]  6.104454  5.684775  1.897604  1.599785  3.863385  9.878326  1.000618  2249.636  2085.735
school_prior_beta[34]  7.458583  7.225070  2.004222  1.917728  4.653268  11.136860  1.003457  1731.328  1917.717
school_prior_beta[35]  6.945114  6.741595  1.839731  1.845919  4.452932  10.384055  1.000970  1898.579  2357.969
school_prior_beta[36]  6.902518  6.354700  2.391516  2.096960  4.101835  11.606285  1.002352  1846.991  2092.054
school_prior_beta[37]  6.405285  5.885545  2.142651  1.836156  3.931134  10.559210  1.001356  2221.362  2639.116
school_prior_beta[38]  5.961819  5.485965  1.881591  1.508886  3.827024  9.690220  1.001043  2653.897  2563.266
school_prior_beta[39]  6.511898  6.020150  2.161627  1.885289  3.964761  10.635820  1.001996  1721.177  2229.240
school_prior_beta[40]  6.412868  5.955380  2.122335  1.871167  3.949675  10.369350  1.000598  2339.874  2541.280
school_prior_beta[41]  6.277202  5.778590  2.065970  1.698911  3.910159  10.228230  1.001257  2913.960  2810.728
school_prior_beta[42]  7.714267  7.296960  2.467924  2.333805  4.517787  12.305235  1.002625  1542.396  1778.872
school_prior_beta[43]  6.491070  6.000300  2.182913  1.873065  3.941435  10.762400  1.001407  1999.256  2222.189
school_prior_beta[44]  6.169196  5.666335  2.053389  1.671631  3.866603  10.055040  1.000085  2892.004  2505.539
school_prior_beta[45]  6.446915  5.880865  2.254919  1.863843  3.879132  10.766230  1.000337  2587.126  2728.151
school_prior_beta[46]  6.232384  5.723025  2.047502  1.684271  3.902608  10.333235  1.001648  2626.035  3080.817
school_prior_beta[47]  5.964638  5.512750  1.883736  1.532445  3.801014  9.786658  1.000639  3195.178  2958.786
school_prior_beta[48]  6.243338  5.748345  2.088923  1.725405  3.836314  10.384930  1.001729  1967.245  2267.657
school_prior_beta[49]  6.452798  5.940115  2.247994  1.836385  3.953984  10.753980  1.001421  2044.521  2594.231
school_prior_beta[50]  6.408177  5.942285  2.159334  1.873065  3.922172  10.577890  1.002382  1653.991  2504.470
school_prior_beta[51]  6.361528  5.835470  2.142953  1.770892  3.926255  10.548125  1.003125  2515.924  2882.958
school_prior_beta[52]  5.537106  5.180855  1.563088  1.253642  3.698441  8.543566  1.000231  3139.606  2477.011
school_prior_beta[53]  5.809495  5.407495  1.759563  1.462993  3.788838  9.244379  1.000570  2770.758  2997.078
school_prior_beta[54]  5.381683  5.021960  1.492077  1.180565  3.664767  8.278310  1.000186  3755.497  2987.081
school_prior_beta[55]  6.031982  5.552770  1.968935  1.581526  3.816465  9.846426  1.001208  2709.514  2989.905
authority_pooled_school_means_plot <- generate_school_plot(alberta_schools_authority_fit)
print(authority_pooled_school_means_plot)

[Figure: school means plot for the authority-pooled model]

Again, it looks like this is working basically as expected!
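
As a quick sanity check on the parameterization used in the Stan model above, here is a minimal R sketch of how a Beta distribution's mean (school_mu) and concentration (school_nu) map to its shape parameters, and how the variance expression in school_sigma_squared simplifies. The numeric values and the variable names ending in `_shape` are made up for illustration and are not taken from the fit.

```r
# Hypothetical values, chosen only for illustration (not from the fitted model).
mu <- 0.6   # Beta mean, plays the role of school_mu
nu <- 30    # Beta concentration, plays the role of school_nu
n  <- 400   # number of students, plays the role of school_n

# Shape parameters, as in the transformed parameters block.
alpha_shape <- mu * nu
beta_shape  <- (1 - mu) * nu

# Variance written the long way, mirroring school_sigma_squared ...
sigma_squared <- (alpha_shape * beta_shape) /
    ((alpha_shape + beta_shape)^2 * (alpha_shape + beta_shape + 1))

# ... and the equivalent short form in terms of mu and nu.
sigma_squared_short <- mu * (1 - mu) / (nu + 1)

# Standard deviation used in the normal likelihood for the school mean.
standard_error <- sqrt(sigma_squared / n)

print(c(alpha = alpha_shape, beta = beta_shape))
print(all.equal(sigma_squared, sigma_squared_short))
print(standard_error)
```

The short form makes it easier to see that a larger school_nu means less spread between students, and a larger school_n means a tighter likelihood around school_mu.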

Final Comments

And there we have it: the same problem solved with various levels of pooling.

One thing that you probably noticed is that all of the results for the last problem were essentially identical! That is absolutely the case, and it is perhaps a good illustration of why you should test these methods on examples before applying them. But why did the level of pooling seem to have so little effect? Essentially, we had too much data! Pooling works by affecting the prior, and since all of our schools had a large number of students, their data overwhelmed the prior, regardless of whether that prior was unique to the school, pooled among all of the schools, or pooled by authority. As a result, we didn’t see the same changes that we did in the earlier examples, where we had fewer data points. This is unsurprising in retrospect, but it also highlights that pooling has the greatest effect when you do not have an abundance of data.
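
To make the "data overwhelms the prior" point concrete, here is a small R sketch. It deliberately swaps in a simpler setting than the school models: a conjugate Beta-Binomial update, where the posterior mean has a closed form. The function name `posterior_mean`, the two priors, and the observed rate are all hypothetical.

```r
# Posterior mean of a proportion under a Beta(prior_alpha, prior_beta) prior
# after observing `successes` out of `n` trials (Beta-Binomial conjugacy).
posterior_mean <- function(successes, n, prior_alpha, prior_beta) {
    (prior_alpha + successes) / (prior_alpha + prior_beta + n)
}

# Two quite different hypothetical priors, standing in for the different
# priors that different levels of pooling would produce.
priors <- list(
    low_prior  = c(alpha = 2, beta = 8),  # prior mean 0.2
    high_prior = c(alpha = 8, beta = 2)   # prior mean 0.8
)

observed_rate <- 0.6
for (n in c(10, 1000)) {
    successes <- observed_rate * n
    estimates <- sapply(priors, function(p) {
        posterior_mean(successes, n, p[["alpha"]], p[["beta"]])
    })
    cat("n =", n, "\n")
    print(round(estimates, 3))
}
```

With 10 observations the two priors pull the estimate to 0.4 and 0.7, but with 1000 observations both land within about 0.005 of the observed rate of 0.6, so the choice of prior barely matters.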

You might also have noticed that the shape of the models changed as we added levels. That is because I had to reparameterize several parameters from a centred to a non-centred form in order to avoid issues with the sampler (divergent transitions, &c.). I didn’t go into details here because the focus was on other parts of the process, but I wrote all of the pooled models several times: non-centring a single parameter, running the model to see if that fixed the problem, and repeating until everything worked. Apparently it’s very common to have to do this with hierarchical models!
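
For anyone who hasn't met the idea before, here is a rough R sketch of what centred and non-centred mean. It only simulates from a normal distribution in two equivalent ways, so it shows that the two forms describe the same distribution, not why the sampler prefers one of them. The values 30 and 4 are hypothetical stand-ins for school_prior_nu_mean and school_prior_nu_sd, and unlike the Stan models above the raw value here is not constrained to be positive.

```r
set.seed(1)
n_draws <- 100000
group_mean <- 30  # hypothetical, plays the role of school_prior_nu_mean
group_sd   <- 4   # hypothetical, plays the role of school_prior_nu_sd

# Centred: draw the parameter directly from its distribution.
centred <- rnorm(n_draws, mean = group_mean, sd = group_sd)

# Non-centred: draw a standard normal "raw" value, then shift and scale it.
raw <- rnorm(n_draws)
non_centred <- group_mean + group_sd * raw

# Both sets of draws should have roughly the same mean and spread.
print(c(mean = mean(centred), sd = sd(centred)))
print(c(mean = mean(non_centred), sd = sd(non_centred)))
```

In the Stan models the same trick is what turns a parameter like school_nu into school_prior_nu_mean + school_prior_nu_sd * school_nu_raw, so that the sampler works with the raw value, whose prior does not depend on the hyperparameters.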

Hopefully this helps with understanding how to build hierarchical models! This is definitely something that I need to work through every time, starting from the unpooled version and then gradually adding in pooling. Despite the added complexity, the additional out-of-sample accuracy and the extra confidence it gives in the model are worth it!

sessionInfo()
R version 4.4.1 (2024-06-14)
Platform: aarch64-apple-darwin23.4.0
Running under: macOS 15.0.1

Matrix products: default
BLAS:   /opt/homebrew/Cellar/openblas/0.3.28/lib/libopenblasp-r0.3.28.dylib
LAPACK: /opt/homebrew/Cellar/r/4.4.1/lib/R/lib/libRlapack.dylib;  LAPACK version 3.12.0

locale:
[1] en_CA/en_CA/en_CA/C/en_CA/en_CA

time zone: America/Edmonton
tzcode source: internal

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base

other attached packages:
[1] readxl_1.4.3    posterior_1.6.0 cmdstanr_0.8.1  ggplot2_3.5.1