diff --git a/src/ensemble.jl b/src/ensemble.jl index 41c07a0c..f6bc761a 100644 --- a/src/ensemble.jl +++ b/src/ensemble.jl @@ -17,11 +17,14 @@ dataset on which the ensembler should be trained on. This function currently assumes that `sol.t` matches the time points of all measurements in `data_ensem`! """ -function ensemble_weights(sol::EnsembleSolution, data_ensem) +function ensemble_weights(sol::EnsembleSolution, data_ensem; rank = Int(round(length(last(first(data_ensem).second))/2))) obs = first.(data_ensem) predictions = reduce(vcat, reduce(hcat,[sol[i][s] for i in 1:length(sol)]) for s in obs) data = reduce(vcat, [data_ensem[i][2] isa Tuple ? data_ensem[i][2][2] : data_ensem[i][2] for i in 1:length(data_ensem)]) - weights = predictions \ data + F = svd(data) + # Truncate SVD + U, S, V = F.U[:, 1:rank], F.S[1:rank], F.V[:, 1:rank] + weights = (((data*V)*Diagonal(1 ./ S)) * U') end function bayesian_ensemble(probs, ps, datas; @@ -46,4 +49,4 @@ function bayesian_ensemble(probs, ps, datas; @info "$(length(all_probs)) total models" enprob = EnsembleProblem(all_probs) -end \ No newline at end of file +end