Model Assessment
Astro 497: Week 4, Monday
Logistics
Lab 4 link
Overview
Validate methods using simulated data
Validate computation
Compare Likelihood to expected distribution
Sensitivity to individual data points
Sensitivity to assumptions (e.g., prior distributions)
Predictive distribution
Cross validation
Validate methods using simulated data
Generate simulated data sets that satisfy all model assumptions.
Apply model and computational methods to simulated data
Compare results to true parameter values.
Checks to perform based on simulated data:
Are estimated model parameters close to true values?
Do estimated parameter uncertainties accurately characterize the dispersion from the true values?
Is the distribution of parameter estimates centered on their true values?
How many iterations are required for algorithms to converge?
What can we learn reliably?
Are there regions of parameter space that result in low quality inference?
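The checks above can be sketched for the simplest possible case. The helper below is hypothetical (it is not the notebook's `generate_θ_fit_distribution`). It assumes a straight-line model with equal, known Gaussian errors. It simulates many datasets, fits each one, and reports three things: the bias of the estimates, whether their actual scatter matches the formal uncertainties, and the fraction of fits that land within 1σ of the truth (about 0.683 if the uncertainties are accurate).

```julia
using LinearAlgebra, Statistics, Random

"""
    validate_line_fit(θ_true, n_obs, σ; n_sim = 2000, rng = MersenneTwister(42))

Hypothetical helper (not part of the original notebook): simulate `n_sim` datasets from
y = θ₁ + θ₂ x + ε with ε ~ N(0, σ²), fit each by least squares, and summarize how well
the estimates and their formal uncertainties describe the scatter about `θ_true`.
"""
function validate_line_fit(θ_true::AbstractVector, n_obs::Integer, σ::Real;
                           n_sim::Integer = 2000, rng::AbstractRNG = MersenneTwister(42))
    θ_fit = zeros(2, n_sim)   # MLE from each simulated dataset
    σ_rep = zeros(2, n_sim)   # reported 1σ uncertainty from each dataset
    for k in 1:n_sim
        x = sort(rand(rng, n_obs))
        A = hcat(ones(n_obs), x)                      # design matrix for a straight line
        y = A * θ_true .+ σ .* randn(rng, n_obs)      # data satisfying all model assumptions
        θ_fit[:, k] = A \ y                           # least squares = MLE for equal Gaussian errors
        σ_rep[:, k] = sqrt.(diag(σ^2 .* inv(A' * A))) # formal uncertainties for this dataset
    end
    bias      = vec(mean(θ_fit, dims = 2)) .- θ_true                # should be ≈ 0
    σ_emp     = vec(std(θ_fit, dims = 2))                           # actual scatter of the estimates
    σ_rms_rep = vec(sqrt.(mean(σ_rep .^ 2, dims = 2)))              # typical reported uncertainty
    coverage  = vec(mean(abs.(θ_fit .- θ_true) .< σ_rep, dims = 2)) # should be ≈ 0.683
    return (; bias, σ_emp, σ_rms_rep, coverage)
end

res = validate_line_fit([1.0, 2.0], 20, 0.5)
```

If `σ_emp` is much larger than `σ_rms_rep`, or `coverage` is well below 0.683, then the reported uncertainties are too optimistic, even for idealized data.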
Validating methods on simulated data is a key step
(too often overlooked)
This approach is very useful for diagnosing many of the most common problems:
Not enough data
Measurement uncertainties too large
Spatial or temporal sampling leads to degeneracies
Model has too many unknown parameters
Iterative algorithms don't converge reliably
Iterative algorithms take longer than expected to converge
Real data is (almost always) more complicated than our idealized model, but...
If our method doesn't work reliably for idealized data, then we shouldn't trust it when applied to real data.
If our method requires $10^6$ iterations to converge for idealized data, then we know we'll need to run it for at least that long for real data (and potentially longer).
Validate the computation
(This step applies regardless of whether analyzing simulated or real data.)
Are the results plausible?
E.g., Is the magnitude of each best-fit parameter value reasonable?
For algorithms that have a stochastic element, are results robust to:
Different (reasonable) initial guesses?
Different draws from the pseudo-random number generator?
For iterative algorithms:
Why did the algorithm stop?
Any warning signs of non-convergence?
For algorithms that return an uncertainty estimate, is the reported precision:
Robust?
Scientifically useful?
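Two of these checks can be sketched with a toy iterative fitter: starting from several initial guesses, and recording *why* the algorithm stopped. The `gradient_descent_fit` function below is hypothetical. It minimizes χ² for a linear model by plain gradient descent. That problem also has a closed-form solution, which serves as a reference answer. The solver returns a `converged` flag instead of silently returning its last iterate.

```julia
using LinearAlgebra, Random

"""
    gradient_descent_fit(A, y, σ, θ0; tol = 1e-8, max_iter = 100_000)

Hypothetical iterative fitter used only to illustrate convergence checks: minimizes
χ²(θ) = Σ((y - Aθ)/σ)² by gradient descent and reports why it stopped.
"""
function gradient_descent_fit(A::AbstractMatrix, y::AbstractVector, σ::Real, θ0::AbstractVector;
                              tol::Real = 1e-8, max_iter::Integer = 100_000)
    θ = float.(θ0)                              # working copy of the initial guess
    step = σ^2 / (2 * opnorm(A' * A))           # 1/L, with L the Lipschitz constant of ∇χ²
    for iter in 1:max_iter
        grad = -2 .* (A' * (y .- A * θ)) ./ σ^2 # gradient of χ²(θ)
        if norm(grad) < tol
            return (θ = θ, converged = true, iterations = iter)      # stopped because it converged
        end
        θ .-= step .* grad
    end
    return (θ = θ, converged = false, iterations = max_iter)         # stopped at the iteration cap
end

rng = MersenneTwister(1)
x = sort(rand(rng, 30))
A = hcat(ones(30), x)
y = A * [1.0, 2.0] .+ 0.1 .* randn(rng, 30)
θ_exact = A \ y                                 # closed-form answer for comparison
fits = [gradient_descent_fit(A, y, 0.1, θ0) for θ0 in ([0.0, 0.0], [10.0, -10.0], [-5.0, 5.0])]
```

A fit that stops with `converged = false`, or that gives different answers from different starting points, is a warning sign worth investigating before interpreting the results.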
Compare Likelihood to expected distribution
Likelihood:
$$\begin{eqnarray} \mathcal{L} & = & p( \mathrm{data} \; | \; \mathrm{model\; parameters}, \; \mathrm{model} ) \\ & = & p (y|\theta,M) = p(y|\theta) \end{eqnarray}$$
Often, we have a complex physical model, but a relatively simple model of the measurement process. In these cases, if we knew that the full model was correct and knew the model parameters ($\theta$), then we could compute the expected distribution of the likelihood.
For simple measurement models, this can often be done analytically. E.g.,
Poisson distribution for sum of counts (if counts are from a Poisson process)
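χ² distribution for sum of squared residuals, each normalized by its measurement uncertainty (if measurement errors are Gaussian)
F-distribution for the ratio of two independent χ²-distributed variables, each divided by its number of degrees of freedom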
For more complex measurement models, this is often done via simulation.
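The simulation approach can be sketched for a case where the analytic answer is known, so the two can be compared. The sketch assumes the model and its true parameters are known, and that the errors are Gaussian with known σ. Under those assumptions, χ² = Σ((y_obs − y_model)/σ)² follows a χ² distribution with `n_obs` degrees of freedom (mean `n_obs`, variance `2n_obs`). If the parameters were instead fitted, the number of degrees of freedom would drop by the number of fitted parameters. `simulate_chi2` is a hypothetical helper, and the "observed" value of 40 is purely illustrative.

```julia
using Statistics, Random

"""
    simulate_chi2(y_model, σ; n_sim = 10_000, rng = MersenneTwister(7))

Hypothetical helper: draw `n_sim` simulated datasets from the model evaluated at the true
parameters (Gaussian errors with standard deviations `σ`) and return χ² for each one.
"""
function simulate_chi2(y_model::AbstractVector, σ::AbstractVector;
                       n_sim::Integer = 10_000, rng::AbstractRNG = MersenneTwister(7))
    chi2 = zeros(n_sim)
    for k in 1:n_sim
        y_sim = y_model .+ σ .* randn(rng, length(y_model))   # simulated measurements
        chi2[k] = sum(((y_sim .- y_model) ./ σ) .^ 2)          # sum of normalized squared residuals
    end
    return chi2
end

n_obs = 25
x = range(0, 1, length = n_obs)
y_model = 1.0 .+ 2.0 .* x              # model at the (assumed known) true parameters
σ_obs = fill(0.3, n_obs)
chi2_sim = simulate_chi2(y_model, σ_obs)
chi2_obs = 40.0                        # purely illustrative "observed" value
p_value = mean(chi2_sim .>= chi2_obs)  # fraction of simulations at least as extreme
```

For a more complex measurement model, only the line that generates `y_sim` would need to change. The rest of the comparison stays the same.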
Sensitivity to individual data points
In practice, there are usually complications omitted from the model being used for inference.
Could be missing astrophysics (e.g., sample of targets is contaminated with a small fraction of objects that differ from the intended sample)
Could be an incomplete model for the measurement process (e.g., cosmic rays, observer points telescope at the wrong object, ...)
To protect against making poor decisions/scientific inferences:
Ask: Would our decisions/inferences change if a small fraction of measurements had been omitted from the analysis?
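One simple version of this check is leave-one-out refitting: refit the model once for each observation, omitting that observation, and see how much the parameters move. The sketch assumes a linear model with equal Gaussian errors, fit by least squares. `leave_one_out_shifts` is a hypothetical helper, and the single contaminated point is injected by hand to stand in for, e.g., a cosmic ray.

```julia
using LinearAlgebra, Random

"""
    leave_one_out_shifts(A, y)

Hypothetical helper: refit a linear model (equal Gaussian errors) omitting one observation
at a time; column i is the change in each parameter when observation i is left out.
"""
function leave_one_out_shifts(A::AbstractMatrix, y::AbstractVector)
    θ_all = A \ y                                  # fit using all observations
    n = length(y)
    shifts = zeros(size(A, 2), n)
    for i in 1:n
        keep = setdiff(1:n, i)                     # indices of the retained observations
        shifts[:, i] = A[keep, :] \ y[keep] .- θ_all
    end
    return shifts
end

rng = MersenneTwister(3)
n = 30
x = sort(rand(rng, n))
A = hcat(ones(n), x)
y = A * [1.0, 2.0] .+ 0.1 .* randn(rng, n)
y[end] += 2.0                                      # hypothetical contaminating measurement
shifts = leave_one_out_shifts(A, y)
most_influential = argmax(vec(maximum(abs.(shifts), dims = 1)))
```

If dropping any single point shifts a parameter by a scientifically meaningful amount (e.g., comparable to its reported uncertainty), then conclusions based on that parameter deserve extra scrutiny.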
Sensitivity to prior distribution/assumptions
In a Bayesian analysis, we specify prior assumptions quantitatively by specifying a prior distribution for unknown model parameters.
There are assumptions in any model (even if they're hard to find, e.g., in a frequentist test).
Usually, reasonable scientists could make different reasonable assumptions.
The precise quantitative results will differ depending on the assumptions.
The differences in the results under different assumptions might be:
So small they are scientifically uninteresting,
So large that more detailed analysis isn't a good use of your time,
Somewhere in the middle.
We need to know where this analysis falls along that spectrum before we start drawing scientific conclusions.
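A minimal sketch of a prior-sensitivity check follows. It assumes a linear model with known Gaussian errors and a Gaussian prior on the parameters. This conjugate case is chosen here only because the posterior is then available in closed form. The analysis is repeated under a very broad prior and a much narrower one, and the shift in the posterior mean is expressed in units of the posterior standard deviation. `posterior_linear` and both prior widths are illustrative assumptions.

```julia
using LinearAlgebra, Random

"""
    posterior_linear(A, y, σ, μ0, Λ0)

Posterior mean and covariance of θ for y ~ N(Aθ, σ²I) with Gaussian prior θ ~ N(μ0, Λ0).
Hypothetical helper for illustration; conjugate case chosen for its closed form.
"""
function posterior_linear(A::AbstractMatrix, y::AbstractVector, σ::Real,
                          μ0::AbstractVector, Λ0::AbstractMatrix)
    Σ_post = inv(A' * A ./ σ^2 + inv(Λ0))          # posterior covariance
    μ_post = Σ_post * (A' * y ./ σ^2 + Λ0 \ μ0)    # posterior mean (= MAP for this model)
    return (; μ_post, Σ_post)
end

rng = MersenneTwister(11)
n = 15
x = sort(rand(rng, n))
A = hcat(ones(n), x)
y = A * [1.0, 2.0] .+ 0.3 .* randn(rng, n)
μ0 = zeros(2)
broad  = posterior_linear(A, y, 0.3, μ0, Diagonal(fill(100.0^2, 2)))   # nearly uninformative prior
narrow = posterior_linear(A, y, 0.3, μ0, Diagonal(fill(0.5^2, 2)))     # strongly informative prior
# Shift between the two analyses, in units of the broad-prior posterior standard deviation
shift_in_σ = abs.(broad.μ_post .- narrow.μ_post) ./ sqrt.(diag(broad.Σ_post))
```

Shifts well below 1σ suggest the prior choice is scientifically unimportant here. Shifts of several σ mean the conclusions depend on the prior, and that dependence should be reported or resolved.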
Predictive distribution
Does the model make good "predictions" for the observed data?
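One way to make this question quantitative is a simplified predictive check. The version sketched here is a plug-in, parametric-bootstrap-style check rather than a full Bayesian posterior predictive check, which would also draw θ from its posterior. It fits the model, simulates replicated datasets from the fitted model, refits each replicate, and asks how often a replicate's χ² is at least as large as the observed one. The sketch assumes equal, known Gaussian errors. `chi2_stat` and `predictive_check` are hypothetical helpers.

```julia
using LinearAlgebra, Statistics, Random

# Hypothetical helper: χ² of data y about the linear model prediction Aθ
chi2_stat(A, y, θ, σ) = sum(((y .- A * θ) ./ σ) .^ 2)

"""
    predictive_check(A, y, σ; n_rep = 1000, rng = MersenneTwister(5))

Fit y ≈ Aθ, simulate replicated datasets from the fitted model, refit each one, and return
the fraction of replicates whose χ² is at least the observed χ².
"""
function predictive_check(A::AbstractMatrix, y::AbstractVector, σ::Real;
                          n_rep::Integer = 1000, rng::AbstractRNG = MersenneTwister(5))
    θ_fit = A \ y
    chi2_obs = chi2_stat(A, y, θ_fit, σ)
    chi2_rep = zeros(n_rep)
    for k in 1:n_rep
        y_rep = A * θ_fit .+ σ .* randn(rng, length(y))   # replicated data from the fitted model
        chi2_rep[k] = chi2_stat(A, y_rep, A \ y_rep, σ)    # refit so both χ² have n - p dof
    end
    return mean(chi2_rep .>= chi2_obs)
end

rng = MersenneTwister(2)
n = 40
x = sort(rand(rng, n))
y = 1.0 .+ 2.0 .* x .+ 3.0 .* x .^ 2 .+ 0.1 .* randn(rng, n)    # truth is quadratic
p_linear    = predictive_check(hcat(ones(n), x), y, 0.1)         # misspecified: straight line
p_quadratic = predictive_check(hcat(ones(n), x, x .^ 2), y, 0.1) # matches the truth
```

A very small fraction (as expected for `p_linear` here) indicates that the observed data look unlike anything the fitted model predicts.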
Numerical experiments
Simulated dataset & best-fit model
θ_true = [a_0, order_true >= 1 ? a_1 : 0.0 , order_true >= 2 ? a_2 : 0.0 , order_true >=3 ? a_3 : 0.0];
begin
redraw_for_plot
x_obs, y_obs, design_matrix, _ = generate_data(θ_true[1:(order_true+1)], n_obs, σ_obs,n_outliers=n_outliers)
plt = scatter(x_obs, y_obs, yerr=fill(σ_obs,n_obs), xlabel=L"x", ylabel=L"y_{obs}", label=:none, legend=:topleft)
θ_fit = calc_mle_linear_model(design_matrix,y_obs, diagm(fill(σ_obs^2,n_obs)))  # covariance = σ² I
y_pred = predict_linear_model(design_matrix,θ_fit)
loss = round(sum(((y_obs.-y_pred)./σ_obs).^2),sigdigits=3)
plot!(plt, x_obs,y_pred, linewidth=4, label="Full model (χ²=$loss)")
ylim = ylims(plt)
if order_true > 1
design_matrix_linear = view(design_matrix,:,1:2)
θ_fit_linear = calc_mle_linear_model(design_matrix_linear,y_obs, diagm(fill(σ_obs^2,n_obs)))  # covariance = σ² I
y_pred_linear = predict_linear_model(design_matrix_linear,θ_fit_linear)
loss_linear = round(sum(((y_obs.-y_pred_linear)./σ_obs).^2),sigdigits=3)
plot!(plt, x_obs,y_pred_linear, linewidth=4, label="Linear model (χ²=$loss_linear)")
end
plt
end
@bind redraw_for_plot Button("Redraw for plot of model & observations")
latexstring("y_{true} = \\sum_{i=0}^{$(order_true)} a_i x^i")
$$y_{true} = \sum_{i=0}^{1} a_i x^i$$
Order of polynomial to generate data:
a₀:
Number of observations:
Measurement uncertainty:
Now let's repeat that analysis many times and compare the MLE estimates for each parameter with the true values.
begin # Analyze many simulated datasets
redraw
results_full_model = generate_θ_fit_distribution(θ_true, n_obs, σ_obs, n_sim = n_datasets, n_outliers=n_outliers)
results_linear_model = generate_θ_fit_distribution(θ_true, n_obs, σ_obs, n_sim = n_datasets, n_outliers=n_outliers, order_fit=1)
end;
PlutoTeachingTools.TwoColumn(plot_parameter_histograms(results_full_model, θ_true, title="Results using full model"),
plot_parameter_histograms(results_linear_model, θ_true, title="Results using linear model"))
#plot_parameter_histograms(results_full_model, θ_true, title="Results using full model")
Number of simulated datasets:
What would have happened if we had included only the intercept and slope terms (and not the $x^2$ or $x^3$ terms)?
#plot_parameter_histograms(results_linear_model, θ_true, title="Results using linear model")
Cross validation
Previously, we fit a model based on all of our observations and then asked if the resulting model made good "predictions". This is a relatively weak test, since the same data were used both to fit the model and to judge its predictions.
A stronger test would be to see if the model can make good predictions for new data.
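Often, acquiring new data is very time consuming.
This motivates cross validation, which approximates that stronger test using only existing data: part of the data is held out from the fit and used to evaluate the model's predictions.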
Cross validation of a single model
Divide data into two subsets (e.g., 75% "training", 25% "validation")
Fit data ("train model") using only training set
Use resulting parameters to make predictions for validation set
How well did the model do? (evaluate "loss function")
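The four steps above can be sketched for polynomial fits like those in the numerical experiments. The sketch assumes least-squares fitting with equal errors and uses mean squared error as the loss function. `poly_design` and `holdout_loss` are hypothetical helpers. The default `rng = MersenneTwister(0)` is re-created on each call, so every call uses the same random split, which keeps the losses for different models comparable.

```julia
using Random, Statistics

# Hypothetical helper: polynomial design matrix with columns x^0, x^1, ..., x^order
poly_design(xs, order) = hcat((xs .^ i for i in 0:order)...)

"""
    holdout_loss(x, y, order; train_frac = 0.75, rng = MersenneTwister(0))

Fit a polynomial of degree `order` by least squares to a random training subset and return
the mean squared prediction error on the held-out validation subset.
"""
function holdout_loss(x::AbstractVector, y::AbstractVector, order::Integer;
                      train_frac::Real = 0.75, rng::AbstractRNG = MersenneTwister(0))
    n = length(y)
    perm = randperm(rng, n)                        # random split into training/validation sets
    n_train = round(Int, train_frac * n)
    train, valid = perm[1:n_train], perm[n_train+1:end]
    θ = poly_design(x[train], order) \ y[train]    # "train model" using the training set only
    y_pred = poly_design(x[valid], order) * θ      # predictions for the validation set
    return mean((y[valid] .- y_pred) .^ 2)         # loss function: mean squared error
end

rng = MersenneTwister(4)
n = 400
x = rand(rng, n)
y = 1.0 .+ 2.0 .* x .+ 3.0 .* x .^ 2 .+ 0.1 .* randn(rng, n)   # truth is quadratic
loss_linear    = holdout_loss(x, y, 1)
loss_quadratic = holdout_loss(x, y, 2)
```

For a model that captures the data well, the validation loss should be close to the measurement variance (σ² = 0.01 in this example).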
Cross validation when considering multiple models
Divide data into three subsets (e.g., 60% "training", 20% "validation", 20% "test")
For each model you want to consider:
Fit/Train using only data from training set
Validate using only data from validation set
Pick which model you'll use for decisions/publication:
Evaluate predictions for test data set.
Report results on test data.
Avoid the temptation to make further "improvements" at this point.
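The workflow above can be sketched as follows. The assumptions are: the candidate models are polynomials of different order, fitting uses least squares, and the loss is mean squared error. `select_and_test` is a hypothetical helper. The point of the structure is that the test set is touched only once, after the model has been chosen using the validation set.

```julia
using Random, Statistics

# Hypothetical helpers (not from the original notebook)
poly_design(xs, order) = hcat((xs .^ i for i in 0:order)...)
mse(y, y_pred) = mean((y .- y_pred) .^ 2)

function select_and_test(x::AbstractVector, y::AbstractVector, orders;
                         rng::AbstractRNG = MersenneTwister(0))
    n = length(y)
    perm = randperm(rng, n)
    n_train, n_valid = round(Int, 0.6n), round(Int, 0.2n)   # 60% / 20% / 20% split
    train = perm[1:n_train]
    valid = perm[n_train+1:n_train+n_valid]
    test  = perm[n_train+n_valid+1:end]
    # Fit each candidate using only the training set, then score it on the validation set
    fits = [poly_design(x[train], k) \ y[train] for k in orders]
    valid_loss = [mse(y[valid], poly_design(x[valid], k) * θ) for (k, θ) in zip(orders, fits)]
    best = argmin(valid_loss)
    # Only the chosen model is evaluated on the test set, exactly once
    test_loss = mse(y[test], poly_design(x[test], orders[best]) * fits[best])
    return (; order = orders[best], valid_loss, test_loss)
end

rng = MersenneTwister(8)
n = 400
x = rand(rng, n)
y = 1.0 .+ 2.0 .* x .+ 3.0 .* x .^ 2 .+ 0.1 .* randn(rng, n)   # truth is quadratic
result = select_and_test(x, y, 1:3)
```

`result.test_loss` is the number to report. Going back to adjust the candidate models after seeing it would turn the test set into a second validation set.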
tip(md"""
**Questions?**
""")
Questions?
Helper Code
ChooseDisplayMode()
TableOfContents(aside=true)
begin
using PlutoUI, PlutoTeachingTools
using LinearAlgebra, Statistics
using Plots, LaTeXStrings
end
function generate_data(θ::AbstractVector, n::Integer, σ::Real; n_outliers::Integer = 0, outlier_factor::Real=5)
order = length(θ)-1
@assert 1<= order <= 3 # linear, quadratic or cubic model
x = sort(rand(n))
A = mapreduce(i->x.^i,hcat,0:order) # design matrix
y_true = A * θ
σ_true = fill(σ,n)
if n_outliers>=1
idx_outliers = sortperm(rand(n))[1:min(n_outliers,n)]  # distinct random indices (no repeats)
σ_true[idx_outliers] .*= outlier_factor
else
idx_outliers = Int64[]
end
y_obs = y_true .+ σ_true .* randn(n)
return (;x_obs = x, y_obs, design_matrix=A, y_true, σ_true, idx_outliers)
end
"""
`predict_linear_model(A, b)`
Computes the predictions of a linear model with design matrix `A` and parameters `b`.
"""
function predict_linear_model(A::AbstractMatrix, b::AbstractVector)
@assert size(A,2) == length(b)
A*b
end
"""
`calc_mle_linear_model(A, y_obs, covar)`
Computes the maximum likelihood estimator for b for the linear model
`y = A b`
where measurement errors of `y_obs` are normally distributed and have covariance `covar`.
"""
function calc_mle_linear_model(A::AbstractMatrix, y_obs::AbstractVector, covar::AbstractMatrix)
@assert size(A,1) == length(y_obs) == size(covar,1) == size(covar,2)
@assert size(A,2) >= 1
(A' * (covar \ A)) \ (A' * (covar \ y_obs) )
end
function generate_θ_fit_distribution(θ::AbstractVector, n_obs::Integer, σ::Real; n_sim::Integer = 100, order_fit::Integer = length(θ)-1, n_outliers::Integer=0)
output = zeros(order_fit+1, n_sim)
for i in 1:n_sim
x_obs, y_obs, design_matrix, _ = generate_data(θ, n_obs, σ, n_outliers=n_outliers)  # use the argument θ, not a global
if length(θ) == order_fit+1 # Fit with same order as generated data
design_matrix_fit = design_matrix
else
design_matrix_fit = view(design_matrix,:,1:(order_fit+1))  # keep only the lowest-order terms
end
θ_fit = calc_mle_linear_model(design_matrix_fit,y_obs, diagm(fill(σ^2,n_obs)))  # covariance = σ² I
output[:,i] .= θ_fit
end
return output
end
function plot_parameter_histograms(results::AbstractMatrix, θ_true; title::String = "")
n_param = size(results,1)
@assert length(θ_true) >= n_param
n_sims = size(results,2)
n_bins = min(100,max(20,floor(Int64,n_sims//10)))
plts = Array{eltype(plot())}(undef, n_param)
for i in 1:n_param
# Plot histogram of MLE estimates
plts[i] = plot(xlabel=latexstring("\\theta_{$i}"), legend=:none)
histogram!(plts[i],results[i,:],nbins=n_bins)
# Add vertical line for true value
ylim = ylims(plts[i])
plot!(plts[i],fill(θ_true[i],2), [0, last(ylim)], linecolor=:red, linewidth=8, linestyle=:solid )
# Annotate with standard deviation of MLEs
xlim = xlims(plts[i])
ann_x = xlim[1] + (xlim[2]-xlim[1])*0.95
ann_y = ylim[1] + (ylim[2]-ylim[1])*0.8
ann = latexstring("\\sigma_{\\theta_$i} = $(round(std(results[i,:]),sigdigits=2))")
annotate!(plts[i], ann_x, ann_y, text(ann,halign=:right))
end
if length(title) >0
title!(plts[1],title)
end
plot(plts..., layout=(n_param,1))
end
Built with Julia 1.8.2 and
LaTeXStrings 1.3.0
Plots 1.31.7
PlutoTeachingTools 0.1.5
PlutoUI 0.7.39
To run this tutorial locally, download this file and open it with Pluto.jl.