Nested cross-validation


Nested cross-validation is said to be an improvement over cross-validation. Unfortunately, I found most explanations quite confusing, so I decided to simulate some data and see what happens.

In this post, I simulate data and fit two models: a linear model, which matches the process that generated the data, and a decision tree, which overfits the data. Next, cross-validation and nested cross-validation are plotted. To keep the post short, I've hidden the code to produce the plots.

    import MLJLinearModels
    import MLJDecisionTreeInterface
    using DataFrames: DataFrame, select, Not
    using Distributions: Normal
    using CairoMakie: Axis, Figure, lines, lines!, scatter, scatter!, current_figure, axislegend, help, linkxaxes!, linkyaxes!, xlabel!, density, density!, hidedecorations!, violin!, boxplot!, hidexdecorations!, hideydecorations!
    using MLJ: CV, evaluate, models, matching, @load, machine, fit!, predict, predict_mode, rms
    using Random: seed!
    using Statistics: mean, std, var, median
    using MLJTuning: TunedModel, Explicit
    using MLJModelInterface: Probabilistic, Deterministic

    # True relationship, and the observed values with Gaussian noise (standard deviation 40).
    y_true(x) = 2x + 10;
    y_real(x) = y_true(x) + rand(Normal(0, 40));
    indexes = 1.0:100;
    df = let
        DataFrame(x = indexes, y = y_real.(indexes))
    end

               x         y
       1     1.0   49.7188
       2     2.0   19.3569
       3     3.0   77.0028
       4     4.0   22.956
       5     5.0  -28.2309
       6     6.0   34.4727
       7     7.0   14.6143
       8     8.0  -17.4941
       9     9.0   46.4924
      10    10.0   26.7763
     ...
     100   100.0  174.396

    LinearModel = @load LinearRegressor pkg=MLJLinearModels verbosity=0;
    TreeModel = @load DecisionTreeRegressor pkg=DecisionTree verbosity=0;
    # Features as a table, target as a vector.
    X, y = (select(df, Not(:y)), df.y);

    function linear_model()
        model = LinearModel(fit_intercept=true)
        mach = machine(model, X, y)
        # Train on all rows; calling predict on an unfitted machine would fail.
        fit!(mach; verbosity=0)
        return mach
    end

    function tree_model()
        model = TreeModel()
        mach = machine(model, X, y)
        fit!(mach; verbosity=0)
        return mach
    end

Okay, so which model performs better? I would guess the LinearRegressor, but let's see what the root-mean-square error (RMS) is when we fit the models and predict on the training data:

    rms(predict(linear_model()), df.y)
    rms(predict(tree_model()), df.y)

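The tree model reaches the lower error on the training data, even though the data were generated from a line plus noise. Clearly, the tree model is overfitting the data. In other words, the model is not expected to perform well on new data.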
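To make the gap between training error and error on new data concrete, here is a minimal sketch in plain Julia. It is not the notebook's code: it assumes a 1-nearest-neighbour predictor as a stand-in for the decision tree (a substitution made here because it needs no packages), and the seed and the names `simulate`, `fit_line`, `fit_nearest` and `rms_error` are made up for this illustration.

```julia
using Random: MersenneTwister
using Statistics: mean, cov, var

rng = MersenneTwister(4)  # arbitrary seed, only for reproducibility

# Same process as in the post: y = 2x + 10 plus Normal(0, 40) noise.
simulate(x) = 2 .* x .+ 10 .+ 40 .* randn(rng, length(x))

xs = collect(1.0:100)
y_train = simulate(xs)  # data the models are fit on
y_new = simulate(xs)    # fresh draw from the same process

rms_error(yhat, y) = sqrt(mean((yhat .- y) .^ 2))

# Least-squares line; returns a function that predicts for new x values.
function fit_line(x, y)
    slope = cov(x, y) / var(x)
    intercept = mean(y) - slope * mean(x)
    return xnew -> slope .* xnew .+ intercept
end

# Stand-in for an overfitting model: copy the y of the nearest training x.
fit_nearest(x, y) = xnew -> [y[argmin(abs.(x .- xi))] for xi in xnew]

line = fit_line(xs, y_train)
nearest = fit_nearest(xs, y_train)

println("line:    train RMS = ", rms_error(line(xs), y_train),
        ", new-data RMS = ", rms_error(line(xs), y_new))
println("nearest: train RMS = ", rms_error(nearest(xs), y_train),
        ", new-data RMS = ", rms_error(nearest(xs), y_new))
```

The nearest-neighbour model reproduces its training data exactly, so its training RMS is 0 by construction, while its error on the fresh draw is not. A training error alone therefore cannot tell us which model to pick.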

Now the question is whether we can determine that the linear model is the right one via cross-validation. Let's first plot the error for each of our $k$ folds.

So, cross-validation isn't going to be perfect. If the data or the standard deviation of the noise had been different, another model could have obtained a lower error according to the cross-validation.
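Since the plotting code is hidden, here is a minimal sketch of what $k$-fold cross-validation computes per fold, written in plain Julia rather than with MLJ. It rests on a few assumptions: a 1-nearest-neighbour predictor stands in for the decision tree, the seed and the choice of 10 folds are arbitrary, and the helper names (`kfolds`, `cv_errors`, `fit_line`, `fit_nearest`) are hypothetical.

```julia
using Random: MersenneTwister, shuffle
using Statistics: mean, cov, var

rng = MersenneTwister(1)  # arbitrary seed, only for reproducibility

# Same process as in the post: y = 2x + 10 plus Normal(0, 40) noise.
xs = collect(1.0:100)
ys = 2 .* xs .+ 10 .+ 40 .* randn(rng, 100)

rms_error(yhat, y) = sqrt(mean((yhat .- y) .^ 2))

# Least-squares line; returns a function that predicts for new x values.
function fit_line(x, y)
    slope = cov(x, y) / var(x)
    intercept = mean(y) - slope * mean(x)
    return xnew -> slope .* xnew .+ intercept
end

# Stand-in for an overfitting model: copy the y of the nearest training x.
fit_nearest(x, y) = xnew -> [y[argmin(abs.(x .- xi))] for xi in xnew]

# Shuffle the indices 1:n and deal them into k test folds.
function kfolds(rng, n, k)
    idx = shuffle(rng, 1:n)
    return [idx[i:k:n] for i in 1:k]
end

# For every fold: fit on the other folds, score on the held-out fold.
function cv_errors(fit, x, y, folds)
    map(folds) do test
        train = setdiff(eachindex(x), test)
        predictor = fit(x[train], y[train])
        rms_error(predictor(x[test]), y[test])
    end
end

folds = kfolds(rng, length(xs), 10)
line_errors = cv_errors(fit_line, xs, ys, folds)
nearest_errors = cv_errors(fit_nearest, xs, ys, folds)

println("line:    mean RMS over folds = ", mean(line_errors))
println("nearest: mean RMS over folds = ", mean(nearest_errors))
```

Each entry of `line_errors` and `nearest_errors` is the error on one held-out fold, which is the quantity a per-fold plot would show. The spread of those entries is why a single cross-validation run can occasionally favour the wrong model.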

Let's try out nested cross-validation.

According to Zhang (2015), repeated 50- and 20-fold CV is best for $n_t$ sample points, and the best cross-validation parameters for model selection are not necessarily the same as the best cross-validation parameters for performance estimation (p. 104 and p. 105).

    function evaluate_inner_folds(nfolds::Int, ntrials::Int)
        # Inner loop: choose between the two models via `nfolds`-fold CV.
        inner_resampling = CV(; nfolds=nfolds)
        multi_model = TunedModel(; models=[LinearModel(), TreeModel()], resampling=inner_resampling)
        # Outer loop: score the whole selection procedure on `ntrials` outer folds.
        outer_resampling = CV(; nfolds=ntrials)
        e = evaluate(multi_model, X, y; measure=rms, resampling=outer_resampling)
        return e
    end

The problem with cross-validation is that it is still possible to overfit during model selection: the winning model was picked for having the lowest error on these folds, so that same error is an optimistic estimate. Therefore, the only reliable way to estimate model performance is to use nested cross-validation (Krstajic et al., 2014). Also, repeated $k$-fold nested cross-validation is the most promising for prediction error estimation.
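The two loops of nested cross-validation can be sketched in plain Julia as follows. This is an illustration under stated assumptions, not the MLJ implementation: a 1-nearest-neighbour predictor stands in for the decision tree, the seed and the fold counts are arbitrary, and `nested_cv`, `cv_error`, `kfolds`, `fit_line` and `fit_nearest` are hypothetical helper names.

```julia
using Random: MersenneTwister, shuffle
using Statistics: mean, cov, var

rng = MersenneTwister(2)  # arbitrary seed, only for reproducibility
xs = collect(1.0:100)
ys = 2 .* xs .+ 10 .+ 40 .* randn(rng, 100)  # same process as in the post

rms_error(yhat, y) = sqrt(mean((yhat .- y) .^ 2))

function fit_line(x, y)
    slope = cov(x, y) / var(x)
    intercept = mean(y) - slope * mean(x)
    return xnew -> slope .* xnew .+ intercept
end

# Stand-in for an overfitting model: copy the y of the nearest training x.
fit_nearest(x, y) = xnew -> [y[argmin(abs.(x .- xi))] for xi in xnew]

function kfolds(rng, n, k)
    idx = shuffle(rng, 1:n)
    return [idx[i:k:n] for i in 1:k]
end

# Mean k-fold RMS of one candidate on the given data.
function cv_error(fit, x, y, k, rng)
    mean(kfolds(rng, length(x), k)) do test
        train = setdiff(eachindex(x), test)
        rms_error(fit(x[train], y[train])(x[test]), y[test])
    end
end

function nested_cv(candidates, x, y; outer_k, inner_k, rng)
    map(kfolds(rng, length(x), outer_k)) do test
        train = setdiff(eachindex(x), test)
        xtrain, ytrain = x[train], y[train]
        # Inner loop: model selection only ever sees the outer training data.
        inner = [cv_error(candidate, xtrain, ytrain, inner_k, rng)
                 for (_, candidate) in candidates]
        name, best_fit = candidates[argmin(inner)]
        # Outer loop: refit the winner and score it on the untouched fold.
        (chosen = name, rms = rms_error(best_fit(xtrain, ytrain)(x[test]), y[test]))
    end
end

candidates = ["line" => fit_line, "nearest" => fit_nearest]
results = nested_cv(candidates, xs, ys; outer_k=5, inner_k=5, rng=rng)

for r in results
    println("chosen = ", r.chosen, ", outer RMS = ", r.rms)
end
println("mean outer RMS = ", mean(r.rms for r in results))
```

Because the outer test fold is never used to choose between the candidates, the mean of the outer errors estimates the error of the whole procedure, selection included, rather than the error of one fixed model.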

Let's see how the plots for nested cross-validation look:

    e = evaluate_inner_folds(20, 20);

That looks interesting, but what happens with the median, mean and variance if we change the number of folds and trials?

I don't know what the take-away should be here. What kind of makes sense is that the variance increases for many trials. The reason is most likely that the samples become too small, so fitting a model is either a complete hit or a miss.

Why the median and mean go down is unclear to me. Maybe fitting is more likely to be a hit than a miss. If so, increasing the number of trials produces more hits, which results in a lower error on average.
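The small-sample explanation can be probed with a sketch in plain Julia, again not the notebook's hidden code. It runs $k$-fold CV for a least-squares line on data simulated as in this post and summarises the per-fold RMS for several fold counts; the seed, the helper names and the chosen fold counts are assumptions for illustration. It also prints the RMS over all held-out points pooled together, because one possible contributor to a falling mean (not verified against the plots in this post) is purely arithmetic: a square root is concave, so the average of per-fold RMS values can never exceed the pooled RMS, and the gap tends to widen when folds are small and their errors vary more.

```julia
using Random: MersenneTwister, shuffle
using Statistics: mean, median, var, cov

rng = MersenneTwister(3)  # arbitrary seed, only for reproducibility
xs = collect(1.0:100)
ys = 2 .* xs .+ 10 .+ 40 .* randn(rng, 100)  # same process as in the post

rms_error(yhat, y) = sqrt(mean((yhat .- y) .^ 2))

function fit_line(x, y)
    slope = cov(x, y) / var(x)
    intercept = mean(y) - slope * mean(x)
    return xnew -> slope .* xnew .+ intercept
end

function kfolds(rng, n, k)
    idx = shuffle(rng, 1:n)
    return [idx[i:k:n] for i in 1:k]
end

# For every fold: fit on the other folds, score on the held-out fold.
function cv_errors(fit, x, y, folds)
    map(folds) do test
        train = setdiff(eachindex(x), test)
        rms_error(fit(x[train], y[train])(x[test]), y[test])
    end
end

# All fold counts divide 100, so every fold has the same size.
summaries = map((5, 10, 20, 50)) do k
    errors = cv_errors(fit_line, xs, ys, kfolds(rng, length(xs), k))
    (k = k,
     fold_size = div(length(xs), k),
     mean_rms = mean(errors),
     median_rms = median(errors),
     var_rms = var(errors),
     # With equal fold sizes this is the RMS over all held-out points.
     pooled_rms = sqrt(mean(errors .^ 2)))
end

for s in summaries
    println(s)
end
```

Comparing `mean_rms` with `pooled_rms` across the rows shows how much of any drop in the mean comes from averaging square roots rather than from the fits themselves.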

Built with Julia 1.7.3 and

CairoMakie 0.7.2
DataFrames 1.3.2
Distributions 0.25.46
MLJ 0.17.1
MLJDecisionTreeInterface 0.1.3
MLJLinearModels 0.6.0
MLJModelInterface 1.3.6
MLJTuning 0.6.16

To run this blog post locally, open this notebook with Pluto.jl.