library("nfidd")
library("dplyr")
library("tidyr")
library("ggplot2")
library("scoringutils")
Evaluating forecasts from multiple models
Introduction
We can classify models along a spectrum by how much they include an understanding of underlying processes, or mechanisms, or by how much they emphasise drawing from the data using a statistical approach. These different approaches all have different strengths and weaknesses, and it is not clear a priori which one produces the best forecast in any given situation.
In this session, we’ll start with forecasts from models with different levels of mechanistic vs. statistical complexity and evaluate them using visualisation and proper scoring rules, as we did in the last session for the random walk model.
Slides
Objectives
The aim of this session is to introduce the concept of a spectrum of forecasting models and to demonstrate how to evaluate a range of different models from across this spectrum.
Source file
The source file of this session is located at sessions/forecast-ensembles.qmd.
Libraries used
In this session we will use the nfidd package to load a data set of infection times and access stan models and helper functions, the dplyr and tidyr packages for data wrangling, the ggplot2 library for plotting, the tidybayes package for extracting results of the inference and the scoringutils package for evaluating forecasts.
The best way to interact with the material is via the Visual Editor of RStudio.
Initialisation
We set a random seed for reproducibility. Setting this ensures that you should get exactly the same results on your computer as we do.
set.seed(123)
Individual forecast models
In this session we will use forecasts from different models. They all share the same basic renewal-with-delays structure but use different models for the evolution of the effective reproduction number over time. The models are:
- A random walk model (what we have looked at so far)
- A simple model of susceptible depletion referred to as “More mechanistic”
- A differenced autoregressive model referred to as “More statistical”
For the purposes of this session the precise details of the models are not critical to the concepts we are exploring.
One way to potentially improve the renewal model is to add more mechanistic structure. In the forecasting visualisation session, we saw that the renewal model was making unbiased forecasts when the reproduction number was constant but that it overestimated the number of cases when the reproduction number was reducing due to susceptible depletion.
This is slightly cheating, as we know the future of this outbreak and so can build a model to match it. This is easy to do and important to watch out for when trying to develop generalisable methods.
This suggests that we should add a term to the renewal model which captures the depletion of susceptibles. One way to do this is to add a term which is proportional to the number of susceptibles in the population. This is the idea behind the SIR model, which is a simple compartmental model of infectious disease transmission. If we assume that susceptible depletion is the only mechanism causing the reproduction number to change, we can write the reproduction number model as:
\[ R_t = \frac{S_{t-1}}{N} R_0 \]
This approximates susceptible depletion as a linear function of the number of susceptibles in the population. This is a simplification but it is a good starting point.
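n <- 100 ## number of time steps
N <- 1000 ## population size
R0 <- 1.5 ## basic reproduction number
S <- rep(NA, n) ## susceptibles
S[1] <- N
Rt <- rep(NA, n) ## reproduction number
Rt[1] <- R0
I <- rep(NA, n) ## new infections
I[1] <- 1
for (i in 2:n) {
  Rt[i] <- (S[i-1]) / N * R0 ## Rt scales with the susceptible fraction
  I[i] <- I[i-1] * Rt[i]
  S[i] <- S[i-1] - I[i] ## newly infected people leave the susceptible pool
}

data <- tibble(t = 1:n, Rt = Rt)

ggplot(data, aes(x = t, y = Rt)) +
  geom_line() +
  labs(title = "Simulated data from an SIR model",
       x = "Time",
       y = "Rt")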
The key assumptions we are making here are:
- The population is constant and we roughly know the size of the population.
- The reproduction number only changes due to susceptible depletion.
- The number of new cases at each time is proportional to the number of susceptibles in the population.
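The formula for \(R_t\) also tells us when the outbreak turns over: \(R_t < 1\) exactly when \(S_{t-1} < N / R_0\), that is, once more than a fraction \(1 - 1/R_0\) of the population (a third, for \(R_0 = 1.5\)) has left the susceptible pool. As a minimal sketch, assuming the same parameters as the simulation above, we can re-run the depletion recursion in base R, wrapped in a hypothetical helper simulate_depletion() so the block stands alone, and check where \(R_t\) first drops below 1.

```r
# Hypothetical helper: the susceptible-depletion recursion from above,
# wrapped in a function so this block runs on its own (base R only).
simulate_depletion <- function(n = 100, N = 1000, R0 = 1.5, I0 = 1) {
  S <- numeric(n)   # susceptibles
  I <- numeric(n)   # new infections
  Rt <- numeric(n)  # reproduction number
  S[1] <- N
  I[1] <- I0
  Rt[1] <- R0
  for (i in 2:n) {
    Rt[i] <- S[i - 1] / N * R0  # R_t scales with the susceptible fraction
    I[i] <- I[i - 1] * Rt[i]    # infections grow or shrink by a factor R_t
    S[i] <- S[i - 1] - I[i]     # newly infected people leave S
  }
  data.frame(t = seq_len(n), S = S, I = I, Rt = Rt)
}

N <- 1000
R0 <- 1.5
sim <- simulate_depletion(N = N, R0 = R0)

threshold <- N / R0              # susceptible count at which R_t = 1
t_below <- which(sim$Rt < 1)[1]  # first time step with R_t < 1

sim$S[t_below - 1] < threshold                 # TRUE: S has just crossed N / R0
all(sim$S[seq_len(t_below - 2)] >= threshold)  # TRUE: it had not crossed before
```

Before that time step the simulated infections grow from one step to the next, and after it they shrink, because each step multiplies the previous infections by \(R_t\).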
Adding more mechanistic structure is not always possible and, if we don’t specify mechanisms correctly, might make forecasts worse. Rather than adding more mechanistic structure to the renewal model, we could add more statistical structure with the aim of improving performance. Before we do this, we need to think about what we want from a forecasting model. As we identified above, we want a model which is unbiased and which has good short-term forecasting properties. We know that we want it to be able to adapt to trends in the reproduction number and to capture the noise in the data. In statistical terms, a time series with a trend is described as non-stationary. More specifically, a stationary time series is one whose statistical properties, such as its mean and variance, do not change over time. In infectious disease epidemiology, this would only be expected for endemic diseases without external seasonal influence.
The random walk model we used in the forecasting visualisation session is a special case of a more general class of models called autoregressive (AR) models. AR models predict the next value in a time series as a linear combination of the previous values in the time series. The random walk model is specifically a special case of an AR(1) model, in which the next value in the time series is predicted as the previous value multiplied by a coefficient, plus some noise. The process is stationary when this coefficient is between -1 and 1, and it becomes a random walk when the coefficient is exactly 1.
For the log-transformed reproduction number (\(log(R_t)\)), the model is:
\[ log(R_t) = \phi log(R_{t-1}) + \epsilon_t \]
where \(\epsilon_t\) is a normally distributed error term with mean 0 and standard deviation \(\sigma\) and \(\phi\) is a parameter between -1 and 1. If we restrict \(\phi\) to be between 0 and 1, we get a model which is biased towards a reproduction number of 1 but which can still capture trends in the data that decay over time.
n <- 100
phi <- 0.4
sigma <- 0.1
log_R <- rep(NA, n)
log_R[1] <- rnorm(1, 0, sigma)
for (i in 2:n) {
  ## damped previous value plus normally distributed noise
  log_R[i] <- phi * log_R[i-1] + rnorm(1, 0, sigma)
}
data <- tibble(t = 1:n, R = exp(log_R))

ggplot(data, aes(x = t, y = R)) +
  geom_line() +
  labs(title = "Simulated data from an exponentiated AR(1) model",
       x = "Time",
       y = "R")
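To see why restricting \(\phi\) in this way pulls \(R_t\) back towards 1, we can follow the mean and variance of \(log(R_t)\) through the AR(1) recursion: the mean is multiplied by \(\phi\) at every step, and the variance obeys \(Var_t = \phi^2 Var_{t-1} + \sigma^2\), which settles at \(\sigma^2 / (1 - \phi^2)\). The sketch below iterates these two recursions deterministically, assuming the same \(\phi\) and \(\sigma\) as above and an arbitrary, illustrative starting value of \(R = 2\) known without uncertainty. No random numbers are drawn, so it does not change the results of later steps in the session.

```r
phi <- 0.4
sigma <- 0.1
n_steps <- 50

# Track the mean and variance of log(R_t) instead of simulating draws.
# Assumed start: log(R_1) is exactly log(2), i.e. R = 2 with no uncertainty.
mean_log_R <- numeric(n_steps)
var_log_R <- numeric(n_steps)
mean_log_R[1] <- log(2)
var_log_R[1] <- 0

for (i in 2:n_steps) {
  mean_log_R[i] <- phi * mean_log_R[i - 1]            # mean shrinks towards 0
  var_log_R[i] <- phi^2 * var_log_R[i - 1] + sigma^2  # noise adds variance
}

stationary_var <- sigma^2 / (1 - phi^2)

exp(mean_log_R[n_steps])             # median of R_t after 50 steps: essentially 1
var_log_R[n_steps] - stationary_var  # essentially 0
```

Because \(log(R_t)\) stays normally distributed here, \(exp\) of its mean is the median of \(R_t\), which decays to 1 however far from 1 we start.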
However, we probably don’t want a model which is biased towards a reproduction number of 1 (unless we have good reason to believe that is the expected behaviour). So what should we do?
Returning to the idea that the reproduction number is a non-stationary time series (we expect it to have trends that we want to capture), we can use a method from the field of time series analysis called differencing to make the time series stationary. This involves taking the difference between consecutive values in the time series and modelling those differences instead, here with an AR(1) model. For the log-transformed reproduction number, this would be:
\[ log(R_t) - log(R_{t-1}) = \phi (log(R_{t-1}) - log(R_{t-2})) + \epsilon_t \]
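Moving \(log(R_{t-1})\) to the right-hand side shows that this is still an autoregressive model for \(log(R_t)\), now of order 2, with coefficients \(1 + \phi\) and \(-\phi\):

\[ log(R_t) = (1 + \phi) log(R_{t-1}) - \phi log(R_{t-2}) + \epsilon_t \]

Because these coefficients sum to 1, there is no pull back towards a fixed level (such as \(R_t = 1\)); instead, recent trends in \(log(R_t)\) carry forward, damped by \(\phi\). Setting \(\phi = 0\) recovers the random walk.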
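Again we look at an R function that implements this model. This one comes from the nfidd package and returns the exponentiated values, i.e. \(R_t\) rather than \(log(R_t)\):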
geometric_diff_ar
function (init, noise, std, damp)
{
n <- length(noise) + 1
x <- numeric(n)
x[1] <- init
x[2] <- x[1] + noise[1] * std
for (i in 3:n) {
x[i] <- x[i - 1] + damp * (x[i - 1] - x[i - 2]) + noise[i -
1] * std
}
return(exp(x))
}
<bytecode: 0x563ce9094fd8>
<environment: namespace:nfidd>
We can use this function to simulate a differenced AR process:
## init is the starting value of log(R); geometric_diff_ar() already returns
## exp(x), so R is on the natural scale and must not be exponentiated again
R <- geometric_diff_ar(init = 1, noise = rnorm(100), std = 0.1, damp = 0.1)
data <- tibble(t = seq_along(R), R = R)

ggplot(data, aes(x = t, y = R)) +
  geom_line() +
  labs(title = "Simulated data from an exponentiated AR(1) model with differencing",
       x = "Time",
       y = "R")
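We can check that differencing really does turn this process into an AR(1). The sketch below uses a hypothetical helper, diff_ar_log(), that copies the recursion in geometric_diff_ar() printed above but returns \(log(R_t)\) instead of \(R_t\), so we can difference it directly. It uses a deterministic stand-in for the noise (values of sin()) so that no random numbers are drawn; the identity being checked holds for any noise sequence, and the starting value init = 0 is an arbitrary choice.

```r
# Hypothetical helper mirroring the recursion in nfidd's geometric_diff_ar(),
# but returning log(R_t) rather than R_t so it can be differenced directly.
diff_ar_log <- function(init, noise, std, damp) {
  n <- length(noise) + 1
  x <- numeric(n)
  x[1] <- init
  x[2] <- x[1] + noise[1] * std
  for (i in 3:n) {
    x[i] <- x[i - 1] + damp * (x[i - 1] - x[i - 2]) + noise[i - 1] * std
  }
  x
}

# Deterministic stand-in for Gaussian noise, so no random numbers are drawn
noise <- sin(seq_len(50))
damp <- 0.1
std <- 0.1

log_R_path <- diff_ar_log(init = 0, noise = noise, std = std, damp = damp)
d <- diff(log_R_path)  # log(R_t) - log(R_{t-1})

# The differences follow an AR(1): d_t = damp * d_{t-1} + std * noise_t
all.equal(d[-1], damp * d[-length(d)] + std * noise[-1])
```

In other words, the damp argument plays the role of \(\phi\) and std × noise plays the role of \(\epsilon_t\) in the differenced equation.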
As previously, we have fitted these models to a range of forecast dates so you don’t have to wait for the models to fit. We will now evaluate the forecasts from the mechanistic and statistical models.
data(rw_forecasts, stat_forecasts, mech_forecasts)
forecasts <- bind_rows(
  rw_forecasts,
  mutate(stat_forecasts, model = "More statistical"),
  mutate(mech_forecasts, model = "More mechanistic")
) |>
  ungroup()

forecasts
# A tibble: 672,000 × 7
day .draw .variable .value horizon target_day model
<dbl> <int> <chr> <dbl> <int> <dbl> <chr>
1 23 1 forecast 4 1 22 Random walk
2 23 2 forecast 2 1 22 Random walk
3 23 3 forecast 2 1 22 Random walk
4 23 4 forecast 6 1 22 Random walk
5 23 5 forecast 2 1 22 Random walk
6 23 6 forecast 3 1 22 Random walk
7 23 7 forecast 5 1 22 Random walk
8 23 8 forecast 2 1 22 Random walk
9 23 9 forecast 3 1 22 Random walk
10 23 10 forecast 5 1 22 Random walk
# ℹ 671,990 more rows
Some important things to note about these forecasts:
- We used a 14 day forecast horizon.
- Each forecast used all the data up to the forecast date.
- We generated 1000 predictive posterior samples for each forecast.
- We started forecasting 3 weeks into the outbreak and then forecast every 7 days until the end of the data (excluding the last 14 days to allow a full forecast).
- We use the same simulated outbreak data as before:
gen_time_pmf <- make_gen_time_pmf()
ip_pmf <- make_ip_pmf()
onset_df <- simulate_onsets(
  make_daily_infections(infection_times), gen_time_pmf, ip_pmf
)
head(onset_df)
# A tibble: 6 × 3
day onsets infections
<dbl> <int> <int>
1 1 0 0
2 2 0 1
3 3 0 0
4 4 0 2
5 5 0 1
6 6 0 1
Visualising your forecast
As in the forecasting evaluation session, we will first visualise the forecasts across multiple forecast dates.
forecasts |>
  filter(.draw %in% sample(.draw, 100)) |>
  ggplot(aes(x = day)) +
  geom_line(aes(y = .value, group = interaction(.draw, target_day), col = target_day), alpha = 0.1) +
  geom_point(data = onset_df |>
    filter(day >= 21),
    aes(x = day, y = onsets), color = "black") +
  scale_color_binned(type = "viridis") +
  facet_wrap(~model) +
  lims(y = c(0, 500))
Warning: Removed 31 rows containing missing values or values outside the scale range
(`geom_line()`).
As for the single forecast, it is also helpful to plot the forecasts on the log scale.
forecasts |>
  filter(.draw %in% sample(.draw, 100)) |>
  ggplot(aes(x = day)) +
  geom_line(
    aes(y = .value, group = interaction(.draw, target_day), col = target_day),
    alpha = 0.1
  ) +
  geom_point(data = onset_df, aes(x = day, y = onsets), color = "black") +
  scale_y_log10() +
  scale_color_binned(type = "viridis") +
  facet_wrap(~model)
Warning in scale_y_log10(): log-10 transformation introduced infinite values.
log-10 transformation introduced infinite values.
How do these forecasts compare? Which do you prefer?
How do these forecasts compare?
- The more mechanistic model captures the downturn in the data very well.
- Past the peak all models were comparable.
- The more statistical model captures the downturn faster than the random walk but less quickly than the more mechanistic model.
- Early on, the more statistical model sporadically predicts a more rapidly growing outbreak than actually occurred.
- The more statistical model is more uncertain than the mechanistic model but less uncertain than the random walk.
Which do you prefer?
- The more mechanistic model seems to be the best at capturing the downturn in the data and the uncertainty in the forecasts seems reasonable.
- If we weren’t confident about the size of the effective susceptible population, the AR model might be preferable.
Scoring your forecast
Again, as in the forecasting evaluation sessions, we will score the forecasts using the scoringutils package by first converting the forecasts to the scoringutils format.
sc_forecasts <- forecasts |>
  left_join(onset_df, by = "day") |>
  filter(!is.na(.value)) |>
  as_forecast_sample(
    forecast_unit = c(
      "target_day", "horizon", "model"
    ),
    observed = "onsets",
    predicted = ".value",
    sample_id = ".draw"
  )
sc_forecasts
Forecast type: sample
Forecast unit:
target_day, horizon, and model
sample_id predicted observed target_day horizon model
<int> <num> <int> <num> <int> <char>
1: 1 4 2 22 1 Random walk
2: 2 2 2 22 1 Random walk
3: 3 2 2 22 1 Random walk
4: 4 6 2 22 1 Random walk
5: 5 2 2 22 1 Random walk
---
671996: 996 1 1 127 14 More mechanistic
671997: 997 7 1 127 14 More mechanistic
671998: 998 2 1 127 14 More mechanistic
671999: 999 1 1 127 14 More mechanistic
672000: 1000 5 1 127 14 More mechanistic
Everything seems to be in order. We can now calculate some metrics as we did in the forecasting concepts session.
sc_scores <- sc_forecasts |>
  score()

sc_scores
target_day horizon model bias dss crps
<num> <int> <char> <num> <num> <num>
1: 22 1 Random walk 0.440 1.937266 0.709424
2: 22 2 Random walk 0.828 3.215491 1.664522
3: 22 3 Random walk 0.636 2.748205 1.230624
4: 22 4 Random walk 0.885 3.932499 2.376469
5: 22 5 Random walk -0.331 2.627111 1.024579
---
668: 127 10 More mechanistic 0.133 1.562747 0.467048
669: 127 11 More mechanistic 0.680 2.464343 1.158455
670: 127 12 More mechanistic 0.843 3.233157 1.656446
671: 127 13 More mechanistic 0.802 2.848762 1.404800
672: 127 14 More mechanistic 0.751 2.446626 1.153386
overprediction underprediction dispersion log_score mad ae_median
<num> <num> <num> <num> <num> <num>
1: 0.240 0.000 0.469424 1.671874 1.4826 1
2: 1.118 0.000 0.546522 2.231108 2.9652 3
3: 0.656 0.000 0.574624 2.011411 2.9652 2
4: 1.668 0.000 0.708469 2.605082 2.9652 4
5: 0.000 0.238 0.786579 2.277217 2.9652 2
---
668: 0.000 0.000 0.467048 1.615033 1.4826 0
669: 0.680 0.000 0.478455 1.845716 1.4826 2
670: 1.170 0.000 0.486446 2.282375 1.4826 2
671: 0.996 0.000 0.408800 2.092336 1.4826 2
672: 0.782 0.000 0.371386 1.817025 1.4826 2
se_mean
<num>
1: 1.943236
2: 8.479744
3: 5.654884
4: 16.378209
5: 1.177225
---
668: 0.264196
669: 4.313929
670: 7.300804
671: 5.503716
672: 4.048144
At a glance
Let’s summarise the scores by model first.
sc_scores |>
  summarise_scores(by = "model")
model bias dss crps overprediction
<char> <num> <num> <num> <num>
1: Random walk 0.21751786 6.380853 13.820203 8.750536
2: More statistical 0.02879464 6.488302 11.974884 6.488286
3: More mechanistic 0.23843304 5.548895 6.334783 2.061241
underprediction dispersion log_score mad ae_median se_mean
<num> <num> <num> <num> <num> <num>
1: 0.717125 4.352542 3.987173 18.251203 18.607143 1595.3668
2: 1.314455 4.172143 4.015102 17.149181 15.910714 1104.6347
3: 2.284071 1.989470 3.704554 8.511712 8.928571 176.2964
Before we look in detail at the scores, what do you think the scores are telling you? Which model do you think is best?
Continuous ranked probability score
As in the forecasting evaluation session, we will start by looking at the CRPS by horizon and forecast date.
- Small values are better
- When scoring absolute values (e.g. number of cases) it can be difficult to compare forecasts across scales (i.e., when case numbers are different, for example between locations or at different times).
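For a forecast given as samples, the CRPS can be estimated as \(E|X - y| - \frac{1}{2} E|X - X'|\), where \(X\) and \(X'\) are independent draws from the forecast and \(y\) is the observation: the first term rewards being close to the data, and the second credits the forecast for its spread. The function below is a minimal sketch of this estimator (crps_from_samples() is a hypothetical name, not the implementation that score() uses). It also illustrates the scale issue above: multiplying both the forecast and the observation by a constant multiplies the CRPS by the same constant.

```r
# Hypothetical helper: sample-based CRPS estimate,
# mean |X - y| - 0.5 * mean |X - X'| over all pairs of samples.
crps_from_samples <- function(samples, observed) {
  abs_error <- mean(abs(samples - observed))         # E|X - y|
  spread <- mean(abs(outer(samples, samples, "-")))  # E|X - X'|
  abs_error - 0.5 * spread
}

# A forecast with no spread reduces to the absolute error
crps_from_samples(rep(3, 10), observed = 5)  # 2

# Two equally likely values, 0 and 1, with the observation at 0
crps_from_samples(c(0, 1), observed = 0)     # 0.25

# The same forecast and observation scaled up by 10: the CRPS is 10 times larger
crps_from_samples(c(0, 10), observed = 0)    # 2.5
```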
First by forecast horizon.
sc_scores |>
  summarise_scores(by = c("model", "horizon")) |>
  ggplot(aes(x = horizon, y = crps, col = model)) +
  geom_point()
and across different forecast dates
sc_scores |>
  summarise_scores(by = c("target_day", "model")) |>
  ggplot(aes(x = target_day, y = crps, col = model)) +
  geom_point()
How do the CRPS values change based on forecast date? How do the CRPS values change with forecast horizon?
How do the CRPS values change based on forecast horizon?
- All models have increasing CRPS with forecast horizon.
- The more mechanistic model has the lowest CRPS at all forecast horizons.
- The more statistical model starts to outperform the random walk model at longer time horizons.
How do the CRPS values change with forecast date?
- The more statistical model does particularly poorly around the peak of the outbreak, although it still outperforms the random walk model.
- The more mechanistic model does particularly well around the peak of the outbreak compared with all other models.
- The more statistical model starts to outperform the other models towards the end of the outbreak.
PIT histograms
- Ideally, PIT histograms should be uniform.
- If the histogram is U-shaped, the model is overconfident; if it has an inverted U shape, the model is underconfident.
- If it is skewed, the model is biased: too many PIT values near 0 indicate overprediction and too many near 1 indicate underprediction.
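The PIT value of an observation is where it falls in the forecast's cumulative distribution \(F\). Because our forecasts are counts, many samples can equal the observation exactly, and one common way to handle this is the randomised PIT, \(F(y - 1) + u (F(y) - F(y - 1))\), with \(u\) drawn uniformly between 0 and 1. The helper below, pit_from_samples(), is a hypothetical sketch of that idea using the empirical CDF of the samples; it is not necessarily how scoringutils computes the histograms. We pass u explicitly in the examples so the output is reproducible.

```r
# Hypothetical helper: randomised PIT for an integer-valued observation,
# F(y - 1) + u * (F(y) - F(y - 1)), where F is the empirical CDF of the samples.
# In practice u is drawn from Uniform(0, 1), which is what the default does.
pit_from_samples <- function(samples, observed, u = runif(1)) {
  F_below <- mean(samples <= observed - 1)  # P(X <= y - 1)
  F_at <- mean(samples <= observed)         # P(X <= y)
  F_below + u * (F_at - F_below)
}

example_samples <- c(1, 2, 2, 3, 5)

# u is fixed here so the results are reproducible
pit_from_samples(example_samples, observed = 2, u = 0.5)   # 0.2 + 0.5 * 0.4 = 0.4
pit_from_samples(example_samples, observed = 10, u = 0.5)  # above every sample: 1
pit_from_samples(example_samples, observed = 0, u = 0.5)   # below every sample: 0
```

Forecasts that are systematically too high leave the observations in their lower tail, so the PIT values pile up near 0; this is how bias shows up as skew in the histograms.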
Let’s first look at the overall PIT histogram.
sc_forecasts |>
  get_pit_histogram(by = "model") |>
  ggplot(aes(x = mid, y = density)) +
  geom_col() +
  facet_wrap(~model)
As before, let’s look at the PIT histogram by forecast horizon (to save space, we will group horizons).
sc_forecasts |>
  mutate(group_horizon = case_when(
    horizon <= 3 ~ "1-3",
    horizon <= 7 ~ "4-7",
    horizon <= 14 ~ "8-14"
  )) |>
  get_pit_histogram(by = c("model", "group_horizon")) |>
  ggplot(aes(x = mid, y = density)) +
  geom_col() +
  facet_grid(vars(model), vars(group_horizon))
and then for different forecast dates.
sc_forecasts |>
  get_pit_histogram(by = c("model", "target_day")) |>
  ggplot(aes(x = mid, y = density)) +
  geom_col() +
  facet_grid(vars(model), vars(target_day))
What do you think of the PIT histograms?
What do you think of the PIT histograms?
- The more mechanistic model is reasonably well calibrated but has a slight tendency to be overconfident.
- The random walk is biased towards overpredicting.
- The more statistical model is underconfident.
- Across horizons the more mechanistic model is only liable to underpredict at the longest horizons.
- The random walk model is initially relatively unbiased and well calibrated but becomes increasingly likely to overpredict as the horizon increases.
- The forecast date stratified PIT histograms are hard to interpret. We may need to find other ways to visualise bias and calibration at this level of stratification (see the {scoringutils} documentation for some ideas).
Scoring on the log scale
Again as in the forecast evaluation session, we will also score the forecasts on the log scale.
log_sc_forecasts <- sc_forecasts |>
  transform_forecasts(
    fun = log_shift,
    offset = 1,
    append = FALSE
  )

log_sc_scores <- log_sc_forecasts |>
  score()
Reminder: For more on scoring on the log scale see the paper by @bosse2023scorin.
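To see why this helps with the scale problem noted for the CRPS, consider a hypothetical forecast that is centred 50% too high, once when 10 cases are observed and once when 100 are. The sketch below uses the sample-based CRPS estimate \(E|X - y| - \frac{1}{2} E|X - X'|\) (defined here as a hypothetical helper, crps_from_samples(), not the scoringutils implementation) and applies \(log(x + 1)\), which is what log_shift() computes with offset = 1. On the natural scale the score grows tenfold with the outbreak size, whereas on the log scale the two scores are similar because the relative error is the same.

```r
# Hypothetical helper: sample-based CRPS, mean |X - y| - 0.5 * mean |X - X'|
crps_from_samples <- function(samples, observed) {
  mean(abs(samples - observed)) - 0.5 * mean(abs(outer(samples, samples, "-")))
}

# Illustrative forecasts centred 50% above the observation, with some spread
multipliers <- seq(0.8, 1.2, by = 0.1)
small <- list(obs = 10, samples = 10 * 1.5 * multipliers)
large <- list(obs = 100, samples = 100 * 1.5 * multipliers)

# Natural scale: the CRPS grows in proportion to the size of the outbreak
natural <- c(
  small = crps_from_samples(small$samples, small$obs),
  large = crps_from_samples(large$samples, large$obs)
)

# Log scale, mirroring log_shift() with offset = 1, i.e. log(x + 1)
log_scale <- c(
  small = crps_from_samples(log(small$samples + 1), log(small$obs + 1)),
  large = crps_from_samples(log(large$samples + 1), log(large$obs + 1))
)

natural[["large"]] / natural[["small"]]      # about 10
log_scale[["large"]] / log_scale[["small"]]  # close to 1
```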
At a glance
log_sc_scores |>
  summarise_scores(by = "model")
model bias dss crps overprediction
<char> <num> <num> <num> <num>
1: Random walk 0.175232143 -0.4690178 0.2299255 0.11532350
2: More statistical -0.009196429 -0.4280965 0.2405370 0.08154039
3: More mechanistic 0.195008929 -1.0961907 0.1861159 0.11374787
underprediction dispersion log_score mad ae_median se_mean
<num> <num> <num> <num> <num> <num>
1: 0.02749138 0.08711065 0.5889563 0.3761217 0.3078707 0.1778818
2: 0.05800097 0.10099568 0.6801200 0.4197095 0.3242281 0.1972211
3: 0.02078127 0.05158679 0.6008573 0.2182572 0.2619873 0.1229412
Before we look in detail at the scores, what do you think the scores are telling you? Which model do you think is best?
CRPS
log_sc_scores |>
  summarise_scores(by = c("model", "horizon")) |>
  ggplot(aes(x = horizon, y = crps, col = model)) +
  geom_point()
log_sc_scores |>
  summarise_scores(by = c("target_day", "model")) |>
  ggplot(aes(x = target_day, y = crps, col = model)) +
  geom_point()
How do the CRPS scores on the log scale compare to the scores on the original scale?
- The performance of the mechanistic model is more variable across forecast horizons than on the natural scale.
- On the log scale, the by-horizon performance of the random walk model and the more statistical model is more similar than on the natural scale.
- The period of high incidence dominates the target day stratified scores less on the log scale. We see that all models performed less well early and late on.
PIT histograms
log_sc_forecasts |>
  get_pit_histogram(by = "model") |>
  ggplot(aes(x = mid, y = density)) +
  geom_col() +
  facet_wrap(~model)
log_sc_forecasts |>
  mutate(group_horizon = case_when(
    horizon <= 3 ~ "1-3",
    horizon <= 7 ~ "4-7",
    horizon <= 14 ~ "8-14"
  )) |>
  get_pit_histogram(by = c("model", "group_horizon")) |>
  ggplot(aes(x = mid, y = density)) +
  geom_col() +
  facet_grid(vars(model), vars(group_horizon))
log_sc_forecasts |>
  get_pit_histogram(by = c("model", "target_day")) |>
  ggplot(aes(x = mid, y = density)) +
  geom_col() +
  facet_grid(vars(model), vars(target_day))
What do you think of the PIT histograms?
The PIT histograms are similar to the original scale PIT histograms but the mechanistic model appears better calibrated.
Going further
- We have only looked at three forecasting models here. There are many more models that could be used. For example, we could use a more complex mechanistic model which captures more of the underlying dynamics of the data generating process. We could also use a more complex statistical model which captures more of the underlying structure of the data.
- We could also combine the more mechanistic and more statistical models to create a hybrid model which captures the best of both worlds (maybe).
- We could also use a more complex scoring rule to evaluate the forecasts. For example, we could use a multivariate scoring rule which captures more of the structure of the data.